mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Merge branch 'develop' into whole_k_prefetch_n0loop
This commit is contained in:
@@ -1,30 +0,0 @@
|
||||
resources:
|
||||
repositories:
|
||||
- repository: pipelines_repo
|
||||
type: github
|
||||
endpoint: ROCm
|
||||
name: ROCm/ROCm
|
||||
|
||||
variables:
|
||||
- group: common
|
||||
- template: /.azuredevops/variables-global.yml@pipelines_repo
|
||||
|
||||
trigger:
|
||||
batch: true
|
||||
branches:
|
||||
include:
|
||||
- develop
|
||||
- amd-develop
|
||||
paths:
|
||||
exclude:
|
||||
- .github
|
||||
- docs
|
||||
- '.*.y*ml'
|
||||
- '*.md'
|
||||
- Jenkinsfile
|
||||
- LICENSE
|
||||
|
||||
pr: none
|
||||
|
||||
jobs:
|
||||
- template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo
|
||||
12
.github/CODEOWNERS
vendored
12
.github/CODEOWNERS
vendored
@@ -1,8 +1,8 @@
|
||||
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd
|
||||
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron
|
||||
# Documentation files
|
||||
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
|
||||
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
|
||||
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
|
||||
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
|
||||
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
|
||||
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
|
||||
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
|
||||
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
|
||||
# Header directory for Doxygen documentation
|
||||
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd
|
||||
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron
|
||||
|
||||
143
.github/scripts/therock_configure_ci.py
vendored
143
.github/scripts/therock_configure_ci.py
vendored
@@ -1,143 +0,0 @@
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Iterable, Optional, Mapping
|
||||
|
||||
|
||||
def gha_set_output(vars: Mapping[str, str | Path]):
|
||||
"""Sets values in a step's output parameters.
|
||||
|
||||
This appends to the file located at the $GITHUB_OUTPUT environment variable.
|
||||
|
||||
See
|
||||
* https://docs.github.com/en/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter
|
||||
* https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs
|
||||
"""
|
||||
print(f"Setting github output:\n{vars}")
|
||||
|
||||
step_output_file = os.getenv("GITHUB_OUTPUT")
|
||||
if not step_output_file:
|
||||
print(" Warning: GITHUB_OUTPUT env var not set, can't set github outputs")
|
||||
return
|
||||
|
||||
with open(step_output_file, "a") as f:
|
||||
f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items())
|
||||
|
||||
|
||||
def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
|
||||
"""Returns the paths of modified files relative to the base reference."""
|
||||
try:
|
||||
return subprocess.run(
|
||||
["git", "diff", "--name-only", base_ref],
|
||||
stdout=subprocess.PIPE,
|
||||
check=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
).stdout.splitlines()
|
||||
except TimeoutError:
|
||||
print(
|
||||
"Computing modified files timed out. Not using PR diff to determine"
|
||||
" jobs to run.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
GITHUB_WORKFLOWS_CI_PATTERNS = [
|
||||
"therock*",
|
||||
]
|
||||
|
||||
|
||||
def is_path_workflow_file_related_to_ci(path: str) -> bool:
|
||||
return any(
|
||||
fnmatch.fnmatch(path, ".github/workflows/" + pattern)
|
||||
for pattern in GITHUB_WORKFLOWS_CI_PATTERNS
|
||||
) or any(
|
||||
fnmatch.fnmatch(path, ".github/scripts/" + pattern)
|
||||
for pattern in GITHUB_WORKFLOWS_CI_PATTERNS
|
||||
)
|
||||
|
||||
|
||||
def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool:
|
||||
if paths is None:
|
||||
return False
|
||||
return any(is_path_workflow_file_related_to_ci(p) for p in paths)
|
||||
|
||||
|
||||
# Paths matching any of these patterns are considered to have no influence over
|
||||
# build or test workflows so any related jobs can be skipped if all paths
|
||||
# modified by a commit/PR match a pattern in this list.
|
||||
SKIPPABLE_PATH_PATTERNS = [
|
||||
"docs/*",
|
||||
"*.gitignore",
|
||||
"*.md",
|
||||
"*.pre-commit-config.*",
|
||||
"*LICENSE",
|
||||
"Jenkinsfile",
|
||||
".github/ISSUE_TEMPLATE/*",
|
||||
".github/CODEOWNERS",
|
||||
".github/*.md",
|
||||
".github/dependabot.yml",
|
||||
]
|
||||
|
||||
|
||||
def is_path_skippable(path: str) -> bool:
|
||||
"""Determines if a given relative path to a file matches any skippable patterns."""
|
||||
return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS)
|
||||
|
||||
|
||||
def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool:
|
||||
"""Returns true if at least one path is not in the skippable set."""
|
||||
if paths is None:
|
||||
return False
|
||||
return any(not is_path_skippable(p) for p in paths)
|
||||
|
||||
|
||||
def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
|
||||
"""Returns true if CI workflows should run given a list of modified paths."""
|
||||
|
||||
if paths is None:
|
||||
print("No files were modified, skipping TheRock CI jobs")
|
||||
return False
|
||||
|
||||
paths_set = set(paths)
|
||||
github_workflows_paths = set(
|
||||
[p for p in paths if p.startswith(".github/workflows")]
|
||||
)
|
||||
other_paths = paths_set - github_workflows_paths
|
||||
|
||||
related_to_ci = check_for_workflow_file_related_to_ci(github_workflows_paths)
|
||||
contains_other_non_skippable_files = check_for_non_skippable_path(other_paths)
|
||||
|
||||
print("should_ci_run_given_modified_paths findings:")
|
||||
print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}")
|
||||
|
||||
if related_to_ci:
|
||||
print("Enabling build jobs since a related workflow file was modified")
|
||||
return True
|
||||
elif contains_other_non_skippable_files:
|
||||
print("Enabling TheRock CI jobs since a non-skippable path was modified")
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
"Only unrelated and/or skippable paths were modified, skipping TheRock CI jobs"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
def main(args):
|
||||
base_ref = args.get("base_ref")
|
||||
modified_paths = get_modified_paths(base_ref)
|
||||
print("modified_paths (max 200):", modified_paths[:200])
|
||||
enable_jobs = should_ci_run_given_modified_paths(modified_paths)
|
||||
output = {"enable_therock_ci": json.dumps(enable_jobs)}
|
||||
gha_set_output(output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = {}
|
||||
args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1")
|
||||
main(args)
|
||||
16
.github/workflows/pre-commit.yml
vendored
16
.github/workflows/pre-commit.yml
vendored
@@ -1,16 +0,0 @@
|
||||
name: pre-commit
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches: [develop]
|
||||
|
||||
jobs:
|
||||
pre-commit:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: '3.12'
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
145
.github/workflows/therock-ci-linux.yml
vendored
145
.github/workflows/therock-ci-linux.yml
vendored
@@ -1,145 +0,0 @@
|
||||
name: TheRock CI Linux
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
cmake_options:
|
||||
type: string
|
||||
amdgpu_families:
|
||||
type: string
|
||||
test_runs_on:
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
therock-build-linux:
|
||||
name: Build Linux Packages
|
||||
runs-on: azure-linux-scale-rocm
|
||||
permissions:
|
||||
id-token: write
|
||||
container:
|
||||
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:2f3ebd0beb04c449fdb36933e54bdc69483b914fb9005594d3fc9444c206b54b
|
||||
options: -v /runner/config:/home/awsconfig/
|
||||
env:
|
||||
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
|
||||
TEATIME_FORCE_INTERACTIVE: 0
|
||||
AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini
|
||||
CACHE_DIR: ${{ github.workspace }}/.container-cache
|
||||
# The ccache.conf will be written by setup_ccache.py before this gets used.
|
||||
CCACHE_CONFIGPATH: ${{ github.workspace }}/.ccache/ccache.conf
|
||||
steps:
|
||||
- name: "Checking out repository for rocm-libraries"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/rocm-libraries"
|
||||
|
||||
- name: Pull DVC files for rocm-libraries # LOGNAME details here https://github.com/ROCm/rocm-libraries/pull/1617
|
||||
run: |
|
||||
if command -v dvc &> /dev/null; then
|
||||
echo "dvc detected"
|
||||
else
|
||||
echo "Warning, dvc not detected!"
|
||||
fi
|
||||
LOGNAME=github-runner dvc pull -v
|
||||
|
||||
- name: Checkout composable_kernel repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
path: "composable_kernel"
|
||||
|
||||
- name: Checkout TheRock repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
path: "TheRock"
|
||||
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
|
||||
|
||||
- name: Setup ccache
|
||||
run: |
|
||||
./TheRock/build_tools/setup_ccache.py \
|
||||
--config-preset "github-oss-presubmit" \
|
||||
--dir "$(dirname $CCACHE_CONFIGPATH)" \
|
||||
--local-path "$CACHE_DIR/ccache"
|
||||
echo "namespace = ext_composable_kernel" >> $CCACHE_CONFIGPATH
|
||||
echo "[*] ccache_config contents:"
|
||||
cat $CCACHE_CONFIGPATH
|
||||
|
||||
- name: Runner Health Settings
|
||||
run: |
|
||||
./TheRock/build_tools/health_status.py
|
||||
|
||||
- name: Fetch sources
|
||||
run: |
|
||||
./TheRock/build_tools/fetch_sources.py --jobs 12 --no-include-rocm-libraries --no-include-ml-frameworks
|
||||
|
||||
- name: Patch rocm-libraries
|
||||
run: |
|
||||
git config --global --add safe.directory '*'
|
||||
# Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo
|
||||
rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch
|
||||
git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch
|
||||
|
||||
- name: Install python deps
|
||||
run: |
|
||||
pip install -r TheRock/requirements.txt
|
||||
pip freeze
|
||||
|
||||
- name: Configure Projects
|
||||
env:
|
||||
amdgpu_families: ${{ env.AMDGPU_FAMILIES }}
|
||||
package_version: ADHOCBUILD
|
||||
extra_cmake_options: ${{ inputs.cmake_options }}
|
||||
BUILD_DIR: build
|
||||
run: |
|
||||
python3 TheRock/build_tools/github_actions/build_configure.py
|
||||
|
||||
- name: Build TheRock
|
||||
run: cmake --build TheRock/build
|
||||
|
||||
- name: Build therock-archives
|
||||
run: cmake --build TheRock/build --target therock-archives
|
||||
|
||||
- name: Report
|
||||
if: ${{ !cancelled() }}
|
||||
run: |
|
||||
echo "Full SDK du:"
|
||||
echo "------------"
|
||||
du -h -d 1 TheRock/build/dist/rocm
|
||||
echo "Artifact Archives:"
|
||||
echo "------------------"
|
||||
ls -lh TheRock/build/artifacts/*.tar.xz
|
||||
echo "Artifacts:"
|
||||
echo "----------"
|
||||
du -h -d 1 TheRock/build/artifacts
|
||||
echo "CCache Stats:"
|
||||
echo "-------------"
|
||||
ccache -s -v
|
||||
tail -v -n +1 .ccache/compiler_check_cache/* > TheRock/build/logs/ccache_compiler_check_cache.log
|
||||
|
||||
- name: Configure AWS Credentials for non-forked repos
|
||||
if: ${{ always() && !github.event.pull_request.head.repo.fork }}
|
||||
uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1
|
||||
with:
|
||||
aws-region: us-east-2
|
||||
role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external
|
||||
|
||||
- name: Post Build Upload
|
||||
if: always()
|
||||
run: |
|
||||
python3 TheRock/build_tools/github_actions/post_build_upload.py \
|
||||
--run-id ${{ github.run_id }} \
|
||||
--artifact-group ${{ env.AMDGPU_FAMILIES }} \
|
||||
--build-dir TheRock/build \
|
||||
--upload
|
||||
|
||||
therock-test-linux:
|
||||
name: "Test"
|
||||
needs: [therock-build-linux]
|
||||
uses: ./.github/workflows/therock-test-packages.yml
|
||||
with:
|
||||
project_to_test: "miopen"
|
||||
amdgpu_families: ${{ inputs.amdgpu_families }}
|
||||
test_runs_on: ${{ inputs.test_runs_on }}
|
||||
platform: "linux"
|
||||
88
.github/workflows/therock-ci.yml
vendored
88
.github/workflows/therock-ci.yml
vendored
@@ -1,88 +0,0 @@
|
||||
name: TheRock CI for composable_kernel
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- develop
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- synchronize
|
||||
branches:
|
||||
- mainline
|
||||
- release/*
|
||||
- release-staging/*
|
||||
- develop
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
# A PR number if a pull request and otherwise the commit hash. This cancels
|
||||
# queued and in-progress runs for the same PR (presubmit) or commit
|
||||
# (postsubmit). The workflow name is prepended to avoid conflicts between
|
||||
# different workflows.
|
||||
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-24.04
|
||||
env:
|
||||
# The commit being checked out is the merge commit for a PR. Its first
|
||||
# parent will be the tip of the base branch.
|
||||
BASE_REF: HEAD^
|
||||
outputs:
|
||||
enable_therock_ci: ${{ steps.configure.outputs.enable_therock_ci }}
|
||||
steps:
|
||||
- name: "Checking out repository"
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
# We need the parent commit to do a diff
|
||||
fetch-depth: 2
|
||||
|
||||
- name: "Configuring CI options"
|
||||
id: configure
|
||||
run: python .github/scripts/therock_configure_ci.py
|
||||
|
||||
therock-ci-linux:
|
||||
name: TheRock CI Linux
|
||||
needs: setup
|
||||
if: ${{ needs.setup.outputs.enable_therock_ci == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
uses: ./.github/workflows/therock-ci-linux.yml
|
||||
secrets: inherit
|
||||
with:
|
||||
cmake_options: >-
|
||||
-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON
|
||||
-DTHEROCK_ENABLE_MIOPEN=ON
|
||||
-DTHEROCK_ENABLE_ALL=OFF
|
||||
-DTHEROCK_USE_EXTERNAL_COMPOSABLE_KERNEL=ON
|
||||
-DTHEROCK_COMPOSABLE_KERNEL_SOURCE_DIR=../composable_kernel
|
||||
-DTHEROCK_USE_EXTERNAL_ROCM_LIBRARIES=ON
|
||||
-DTHEROCK_ROCM_LIBRARIES_SOURCE_DIR=../
|
||||
amdgpu_families: "gfx94X-dcgpu"
|
||||
test_runs_on: "linux-mi325-1gpu-ossci-rocm"
|
||||
|
||||
therock_ci_summary:
|
||||
name: TheRock CI Summary
|
||||
if: always()
|
||||
needs:
|
||||
- setup
|
||||
- therock-ci-linux
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- name: Output failed jobs
|
||||
run: |
|
||||
echo '${{ toJson(needs) }}'
|
||||
FAILED_JOBS="$(echo '${{ toJson(needs) }}' \
|
||||
| jq --raw-output \
|
||||
'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \
|
||||
)"
|
||||
if [[ "${FAILED_JOBS}" != "" ]]; then
|
||||
echo "The following jobs failed: ${FAILED_JOBS}"
|
||||
exit 1
|
||||
fi
|
||||
72
.github/workflows/therock-test-component.yml
vendored
72
.github/workflows/therock-test-component.yml
vendored
@@ -1,72 +0,0 @@
|
||||
name: Test component
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
artifact_run_id:
|
||||
type: string
|
||||
default: ""
|
||||
amdgpu_families:
|
||||
type: string
|
||||
test_runs_on:
|
||||
type: string
|
||||
platform:
|
||||
type: string
|
||||
component:
|
||||
type: string
|
||||
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
test_component:
|
||||
name: 'Test ${{ fromJSON(inputs.component).job_name }} (shard ${{ matrix.shard }} of ${{ fromJSON(inputs.component).total_shards }})'
|
||||
runs-on: ${{ inputs.test_runs_on }}
|
||||
container:
|
||||
image: ${{ inputs.platform == 'linux' && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:4150afe4759d14822f0e3f8930e1124f26e11f68b5c7b91ec9a02b20b1ebbb98' || null }}
|
||||
options: --ipc host
|
||||
--group-add video
|
||||
--device /dev/kfd
|
||||
--device /dev/dri
|
||||
--group-add 110
|
||||
--env-file /etc/podinfo/gha-gpu-isolation-settings
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# The shard array is based on "total_shards" from "fetch_test_configurations.py"
|
||||
# The test executable will shard based on the array. (ex: [1, 2, 3, 4] = four test shards)
|
||||
shard: ${{ fromJSON(inputs.component).shard_arr }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
env:
|
||||
VENV_DIR: ${{ github.workspace }}/.venv
|
||||
ARTIFACT_RUN_ID: "${{ inputs.artifact_run_id != '' && inputs.artifact_run_id || github.run_id }}"
|
||||
OUTPUT_ARTIFACTS_DIR: "./build"
|
||||
THEROCK_BIN_DIR: "./build/bin"
|
||||
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
|
||||
|
||||
- name: Run setup test environment workflow
|
||||
uses: './.github/actions/setup_test_environment'
|
||||
with:
|
||||
ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }}
|
||||
ARTIFACT_GROUP: ${{ inputs.amdgpu_families }}
|
||||
OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }}
|
||||
VENV_DIR: ${{ env.VENV_DIR }}
|
||||
FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }}
|
||||
IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }}
|
||||
|
||||
- name: Test
|
||||
timeout-minutes: ${{ fromJSON(inputs.component).timeout_minutes }}
|
||||
env:
|
||||
SHARD_INDEX: ${{ matrix.shard }}
|
||||
TOTAL_SHARDS: ${{ fromJSON(inputs.component).total_shards }}
|
||||
run: |
|
||||
${{ fromJSON(inputs.component).test_script }}
|
||||
54
.github/workflows/therock-test-packages.yml
vendored
54
.github/workflows/therock-test-packages.yml
vendored
@@ -1,54 +0,0 @@
|
||||
name: TheRock Test Packages
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
project_to_test:
|
||||
type: string
|
||||
amdgpu_families:
|
||||
type: string
|
||||
test_runs_on:
|
||||
type: string
|
||||
platform:
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
configure_test_matrix:
|
||||
name: "Configure test matrix"
|
||||
runs-on: ubuntu-24.04
|
||||
if: ${{ inputs.test_runs_on != '' }}
|
||||
outputs:
|
||||
components: ${{ steps.configure.outputs.components }}
|
||||
steps:
|
||||
- name: "Checking out repository"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
|
||||
|
||||
- name: "Configuring CI options"
|
||||
env:
|
||||
PLATFORM: ${{ inputs.platform }}
|
||||
project_to_test: ${{ inputs.project_to_test }}
|
||||
id: configure
|
||||
run: python ./build_tools/github_actions/fetch_test_configurations.py
|
||||
|
||||
test_components:
|
||||
name: 'Test ${{ matrix.components.job_name }}'
|
||||
needs: [configure_test_matrix]
|
||||
# skip tests if no test matrix to run
|
||||
if: ${{ needs.configure_test_matrix.outputs.components != '[]' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }}
|
||||
uses: './.github/workflows/therock-test-component.yml'
|
||||
with:
|
||||
artifact_run_id: ${{ github.run_id }}
|
||||
amdgpu_families: ${{ inputs.amdgpu_families }}
|
||||
test_runs_on: ${{ inputs.test_runs_on }}
|
||||
platform: ${{ inputs.platform }}
|
||||
component: ${{ toJSON(matrix.components) }}
|
||||
29
.gitignore
vendored
29
.gitignore
vendored
@@ -36,6 +36,9 @@ tags
|
||||
# Editors
|
||||
.vscode
|
||||
|
||||
# CMake formatting configuration (local)
|
||||
.cmake-format.yaml
|
||||
|
||||
# Cline
|
||||
.cline*
|
||||
|
||||
@@ -78,9 +81,35 @@ CMakeUserPresets.json
|
||||
# Python cache
|
||||
__pycache__/
|
||||
|
||||
# Cache directories
|
||||
.cache/
|
||||
.ck_tile_cache/
|
||||
ck_tile_cache/
|
||||
**/kernel_cache/
|
||||
**/.kernel_cache/
|
||||
|
||||
# Dispatcher kernel cache (user-generated, can be large)
|
||||
dispatcher/**/kernel_cache/
|
||||
dispatcher/**/.kernel_cache/
|
||||
dispatcher/**/cached_kernels/
|
||||
dispatcher/**/*.hsaco
|
||||
dispatcher/**/*.co
|
||||
|
||||
# Dispatcher generated JSON exports
|
||||
dispatcher/**/*_kernels.json
|
||||
dispatcher/**/dispatcher_kernels.json
|
||||
|
||||
# Generated test data
|
||||
test_data/*
|
||||
!test_data/*.py
|
||||
!test_data/*.sh
|
||||
!test_data/requirements.txt
|
||||
|
||||
# Exceptions to build* patterns above
|
||||
# The experimental/builder directory should be tracked despite matching build*
|
||||
!experimental/builder
|
||||
!experimental/builder/**
|
||||
experimental/grouped_convolution_tile_instances/instances/*
|
||||
!experimental/grouped_convolution_tile_instances/instances/*.in
|
||||
!experimental/grouped_convolution_tile_instances/instances/*.inc
|
||||
experimental/grouped_convolution_tile_instances/*.inc
|
||||
|
||||
@@ -20,21 +20,21 @@ repos:
|
||||
)$
|
||||
- repo: local
|
||||
hooks:
|
||||
# - id: copyright-year-checker
|
||||
# name: copyright-year-checker
|
||||
# entry: script/check_copyright_year.sh
|
||||
# verbose: false
|
||||
# language: script
|
||||
# types: [c++]
|
||||
- id: copyright-header-checker
|
||||
name: Check copyright headers
|
||||
entry: projects/composablekernel/script/check_copyright_year.sh
|
||||
verbose: false
|
||||
language: script
|
||||
types_or: [c++, python, shell, cmake]
|
||||
- id: remove-exec-bit
|
||||
name: Remove executable bit from non-executable files
|
||||
entry: script/remove_exec_bit.sh
|
||||
entry: projects/composablekernel/script/remove_exec_bit.sh
|
||||
language: script
|
||||
types_or: [c++, text]
|
||||
verbose: true
|
||||
- id: remod-ck-tile
|
||||
name: Run ck_tile remod.py
|
||||
entry: python script/remod_for_ck_tile.py
|
||||
entry: python projects/composablekernel/script/remod_for_ck_tile.py
|
||||
language: python
|
||||
files: '^(include|example)/ck_tile/.*$'
|
||||
additional_dependencies:
|
||||
|
||||
51
CHANGELOG.md
51
CHANGELOG.md
@@ -2,32 +2,58 @@
|
||||
|
||||
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
|
||||
|
||||
## (Unreleased) Composable Kernel 1.3.0
|
||||
|
||||
### Added
|
||||
* Added preshuffleB support for abquant mode in blockscale GEMM.
|
||||
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
|
||||
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
|
||||
* Added streamingllm sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
|
||||
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
|
||||
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
|
||||
* Added FP8 KV cache support for FMHA batch prefill.
|
||||
* Added support for gfx1153 target.
|
||||
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
|
||||
* Added gpt-oss sink support for FMHA FWD, include qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
|
||||
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
|
||||
* Added FP8 block scale quantization for FMHA forward kernel.
|
||||
* Added gfx11 support for FMHA.
|
||||
|
||||
### Changed
|
||||
|
||||
### Upcoming changes
|
||||
|
||||
## Composable Kernel 1.2.0 for ROCm 7.2.0
|
||||
|
||||
### Added
|
||||
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
|
||||
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support.
|
||||
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
|
||||
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM
|
||||
* Added a compute async pipeline in the CK TILE universal GEMM on gfx950
|
||||
* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM.
|
||||
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.
|
||||
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM.
|
||||
* Added a compute async pipeline in the CK Tile universal GEMM on gfx950.
|
||||
* Added support for B Tensor type `pk_int4_t` in the CK Tile weight preshuffle GEMM.
|
||||
* Added the new api to load different memory sizes to SGPR.
|
||||
* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.
|
||||
* Added support for B Tensor preshuffle in CK Tile grouped GEMM.
|
||||
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
|
||||
* Added support for grouped_gemm kernels to perform multi_d elementwise operation.
|
||||
* Added support for Multiple ABD GEMM
|
||||
* Added support for grouped GEMM kernels to perform Multi D elementwise operation.
|
||||
* Added support for multiple ABD GEMM.
|
||||
* Added benchmarking support for tile engine GEMM Multi D.
|
||||
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
|
||||
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
|
||||
* Added support for f32 to FMHA (fwd/bwd).
|
||||
* Added tensor-wise quantization for CK_TILE GEMM.
|
||||
* Added block scaling support in CK Tile GEMM, allowing flexible use of quantization matrices from either A or B operands.
|
||||
* Added the row-wise column-wise quantization for CK Tile GEMM and CK Tile grouped GEMM.
|
||||
* Added support for f32 to FMHA (forward and backward).
|
||||
* Added tensor-wise quantization for CK Tile GEMM.
|
||||
* Added support for batched contraction kernel.
|
||||
* Added WMMA (gfx12) support for FMHA.
|
||||
* Added pooling kernel in CK_TILE
|
||||
* Added top-k sigmoid kernel in CK_TILE
|
||||
* Added the blockscale 2D support for CK_TILE GEMM.
|
||||
* Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types
|
||||
* Added reduce and multi reduction kernels
|
||||
|
||||
### Changed
|
||||
|
||||
* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594)
|
||||
* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK Tile (#2594)
|
||||
* Added an optional template parameter `Arch` (`gfx9_t`, `gfx12_t` etc.) to `make_kernel` to support linking multiple object files that have the same kernel compiled for different architectures.
|
||||
* FMHA examples and tests can be built for multiple architectures (gfx9, gfx950, gfx12) at the same time.
|
||||
|
||||
@@ -73,11 +99,12 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
|
||||
* Added rotating buffer feature for CK_Tile GEMM.
|
||||
* Added int8 support for CK_TILE GEMM.
|
||||
* Added CK Tile Epilogue Chainer framework for composable epilogue sequences in GEMM operations
|
||||
|
||||
### Optimized
|
||||
|
||||
* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout.
|
||||
* Added Vectorize Transpose optimization for CK Tile
|
||||
* Added Vectorize Transpose optimization for CK Tile
|
||||
* Added the asynchronous copy for gfx950
|
||||
|
||||
### Changed
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
cmake_minimum_required(VERSION 3.14)
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
cmake_minimum_required(VERSION 3.21)
|
||||
if(POLICY CMP0140)
|
||||
# policies CMP0140 not known to CMake until 3.25
|
||||
cmake_policy(SET CMP0140 NEW)
|
||||
@@ -28,21 +31,31 @@ endif()
|
||||
# Default installation path
|
||||
if(NOT WIN32)
|
||||
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
|
||||
else()
|
||||
set(CMAKE_INSTALL_PREFIX "C:/dist/TheRock" CACHE PATH "")
|
||||
endif()
|
||||
|
||||
# Enable ASAN when THEROCK_SANITIZER is set to ASAN or HOST_ASAN
|
||||
if(THEROCK_SANITIZER STREQUAL "ASAN" OR THEROCK_SANITIZER STREQUAL "HOST_ASAN")
|
||||
set(ENABLE_ASAN_PACKAGING ON)
|
||||
message(STATUS "Enabling ASAN for Composable Kernel (THEROCK_SANITIZER=${THEROCK_SANITIZER})")
|
||||
endif()
|
||||
|
||||
set(version 1.2.0)
|
||||
# Check support for CUDA/HIP in Cmake
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX)
|
||||
include(CTest)
|
||||
|
||||
option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON)
|
||||
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
|
||||
option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF)
|
||||
option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
|
||||
option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
|
||||
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
add_definitions(-DCK_EXPERIMENTAL_BUILDER)
|
||||
include_directories(${PROJECT_SOURCE_DIR}/experimental/builder/include)
|
||||
include_directories(${PROJECT_SOURCE_DIR}/experimental/builder/include)
|
||||
endif()
|
||||
|
||||
# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
|
||||
@@ -87,6 +100,10 @@ if (DTYPES)
|
||||
add_definitions(-DCK_ENABLE_FP32)
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "tf32")
|
||||
# definition will be added based on the GPU target in the following section
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp64")
|
||||
add_definitions(-DCK_ENABLE_FP64)
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
@@ -101,6 +118,7 @@ else()
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
@@ -113,6 +131,9 @@ add_compile_options(-Wno-pass-failed)
|
||||
add_compile_options(-Wno-switch-default)
|
||||
add_compile_options(-Wno-unique-object-duplication)
|
||||
|
||||
# Increase the number of max elements in fold expressions
|
||||
add_compile_options(-fbracket-depth=1024)
|
||||
|
||||
# add -Og -gdwarf64 for debug builds
|
||||
add_compile_options(
|
||||
"$<$<CONFIG:Debug>:-Og>"
|
||||
@@ -152,7 +173,13 @@ execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMI
|
||||
configure_file(include/ck/version.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/version.h)
|
||||
|
||||
set(ROCM_SYMLINK_LIBS OFF)
|
||||
find_package(ROCM REQUIRED PATHS /opt/rocm)
|
||||
|
||||
if (WIN32)
|
||||
find_package(ROCmCMakeBuildTools REQUIRED PATHS C:/dist/TheRock)
|
||||
set(HIP_PLATFORM "amd" CACHE STRING "HIP platform")
|
||||
else()
|
||||
find_package(ROCM REQUIRED PATHS /opt/rocm)
|
||||
endif()
|
||||
|
||||
include(ROCMInstallTargets)
|
||||
include(ROCMPackageConfigHelpers)
|
||||
@@ -179,7 +206,10 @@ if(GPU_TARGETS)
|
||||
else()
|
||||
set(USER_GPU_TARGETS 0)
|
||||
endif()
|
||||
|
||||
find_package(hip REQUIRED)
|
||||
enable_language(HIP)
|
||||
|
||||
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
|
||||
# SWDEV-413293 and https://reviews.llvm.org/D155213
|
||||
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
|
||||
@@ -204,7 +234,7 @@ if(NOT ENABLE_ASAN_PACKAGING)
|
||||
endif()
|
||||
else()
|
||||
#build CK only for xnack-supported targets when using ASAN
|
||||
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+")
|
||||
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+;gfx950:xnack+")
|
||||
endif()
|
||||
|
||||
#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list
|
||||
@@ -229,16 +259,21 @@ message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}"
|
||||
# Cache SUPPORTED_GPU_TARGETS for debug
|
||||
set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets")
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12" AND NOT FORCE_DISABLE_XDL)
|
||||
message(STATUS "Enabling XDL instances")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND NOT FORCE_DISABLE_XDL)
|
||||
message(STATUS "Enabling XDL FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_GFX94)
|
||||
set(CK_USE_GFX94 "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950" AND NOT FORCE_DISABLE_XDL)
|
||||
message(STATUS "Enabling XDL FP8 gemms on gfx950")
|
||||
add_definitions(-DCK_USE_GFX950)
|
||||
set(CK_USE_GFX950 "ON")
|
||||
endif()
|
||||
|
||||
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
|
||||
set(CK_TILE_USE_WMMA 0)
|
||||
@@ -247,7 +282,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10")
|
||||
add_definitions(-DCK_GFX1030_SUPPORT)
|
||||
endif()
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA)
|
||||
message(STATUS "Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
@@ -257,7 +292,7 @@ endif()
|
||||
# define the macro with the current value (0 or 1)
|
||||
add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA)
|
||||
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_WMMA_FP8)
|
||||
set(CK_USE_WMMA_FP8 "ON")
|
||||
@@ -277,6 +312,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
|
||||
set(CK_GFX950_SUPPORT "ON")
|
||||
endif()
|
||||
|
||||
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
|
||||
add_definitions(-DCK_ENABLE_TF32)
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
else()
|
||||
message(STATUS "Disabling TF32 instances")
|
||||
remove_definitions(-DCK_ENABLE_TF32)
|
||||
set(CK_ENABLE_TF32 "OFF")
|
||||
endif()
|
||||
|
||||
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
|
||||
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
|
||||
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
@@ -614,7 +658,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS
|
||||
add_compile_options(-fdiagnostics-color=always)
|
||||
endif()
|
||||
|
||||
if(NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
if(NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
|
||||
# make check runs the entire set of examples and tests
|
||||
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} USES_TERMINAL)
|
||||
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
|
||||
@@ -625,7 +669,9 @@ endif()
|
||||
|
||||
|
||||
|
||||
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
|
||||
# Optimization: Search only in library/src where all instance files actually live
|
||||
# (was searching entire source tree, taking ~40s instead of <1s)
|
||||
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/library/src/*/device_*_instance.cpp")
|
||||
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
|
||||
set(CK_DEVICE_INSTANCES)
|
||||
FOREACH(subdir_path ${dir_list})
|
||||
@@ -646,6 +692,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
|
||||
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
@@ -667,12 +716,18 @@ ENDFOREACH()
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
|
||||
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
|
||||
option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF)
|
||||
option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
|
||||
add_subdirectory(library)
|
||||
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
if (CK_EXPERIMENTAL_BUILDER)
|
||||
add_subdirectory(experimental/builder)
|
||||
add_subdirectory(experimental/grouped_convolution_tile_instances)
|
||||
endif()
|
||||
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
@@ -695,7 +750,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
if (NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
@@ -703,10 +758,6 @@ if (NOT MIOPEN_REQ_LIBS_ONLY)
|
||||
add_subdirectory(profiler)
|
||||
endif()
|
||||
|
||||
if (CK_EXPERIMENTAL_BUILDER)
|
||||
add_subdirectory(experimental/builder)
|
||||
endif()
|
||||
|
||||
if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
|
||||
add_subdirectory(codegen)
|
||||
endif()
|
||||
@@ -739,6 +790,16 @@ rocm_install(FILES
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/
|
||||
)
|
||||
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
rocm_install(DIRECTORY
|
||||
${PROJECT_SOURCE_DIR}/experimental/builder/include/ck_tile/builder
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile
|
||||
)
|
||||
|
||||
set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
|
||||
rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile)
|
||||
endif()
|
||||
|
||||
set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
|
||||
set(CPACK_RPM_PACKAGE_LICENSE "MIT")
|
||||
|
||||
|
||||
91
CMakePresets.json
Normal file
91
CMakePresets.json
Normal file
@@ -0,0 +1,91 @@
|
||||
{
|
||||
"version": 3,
|
||||
"cmakeMinimumRequired": {
|
||||
"major": 3,
|
||||
"minor": 21,
|
||||
"patch": 0
|
||||
},
|
||||
"configurePresets": [
|
||||
{
|
||||
"name": "use-gfx908",
|
||||
"hidden": true,
|
||||
"cacheVariables": {
|
||||
"GPU_TARGETS": "gfx908"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "use-gfx90a",
|
||||
"hidden": true,
|
||||
"cacheVariables": {
|
||||
"GPU_TARGETS": "gfx90a"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "use-gfx942",
|
||||
"hidden": true,
|
||||
"cacheVariables": {
|
||||
"GPU_TARGETS": "gfx942"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "use-gfx950",
|
||||
"hidden": true,
|
||||
"cacheVariables": {
|
||||
"GPU_TARGETS": "gfx950"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "dev",
|
||||
"binaryDir": "${sourceDir}/build",
|
||||
"displayName": "CK Dev",
|
||||
"environment": {},
|
||||
"cacheVariables": {
|
||||
"CMAKE_PREFIX_PATH": "/opt/rocm/",
|
||||
"CMAKE_CXX_COMPILER": "/opt/rocm/llvm/bin/clang++",
|
||||
"CMAKE_HIP_COMPILER": "/opt/rocm/llvm/bin/clang++",
|
||||
"CMAKE_CXX_FLAGS": "-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -fbracket-depth=1024",
|
||||
"CMAKE_BUILD_TYPE": "Release",
|
||||
"BUILD_DEV": "ON",
|
||||
"CMAKE_VERBOSE_MAKEFILE": "ON",
|
||||
"USE_BITINT_EXTENSION_INT4": "OFF",
|
||||
"GPU_TARGETS": "gfx908;gfx90a;gfx942"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "dev-gfx908",
|
||||
"displayName": "CK Dev - gfx908",
|
||||
"description": "Development build for AMD GPU gfx908",
|
||||
"inherits": [
|
||||
"use-gfx908",
|
||||
"dev"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "dev-gfx90a",
|
||||
"displayName": "CK Dev - gfx90a",
|
||||
"description": "Development build for AMD GPU gfx90a",
|
||||
"inherits": [
|
||||
"use-gfx90a",
|
||||
"dev"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "dev-gfx942",
|
||||
"displayName": "CK Dev - gfx942",
|
||||
"description": "Development build for AMD GPU gfx942",
|
||||
"inherits": [
|
||||
"use-gfx942",
|
||||
"dev"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "dev-gfx950",
|
||||
"displayName": "CK Dev - gfx950",
|
||||
"description": "Development build for AMD GPU gfx950",
|
||||
"inherits": [
|
||||
"use-gfx950",
|
||||
"dev"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
|
||||
FROM ubuntu:24.04
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG ROCMVERSION=7.0.1
|
||||
ARG ROCMVERSION=7.1.1
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
ARG CK_SCCACHE=""
|
||||
@@ -13,8 +13,8 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN set -xe && \
|
||||
apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl
|
||||
|
||||
RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \
|
||||
apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \
|
||||
RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
|
||||
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
|
||||
apt update && \
|
||||
apt install python3-setuptools python3-wheel -y && \
|
||||
apt install rocm-dev -y
|
||||
|
||||
@@ -2,7 +2,7 @@ ARG BASE_DOCKER="rocm/pytorch:latest"
|
||||
FROM $BASE_DOCKER
|
||||
ARG AITER_BRANCH="main"
|
||||
ARG CK_AITER_BRANCH="develop"
|
||||
RUN pip install pandas zmq einops ninja && \
|
||||
RUN pip install pandas zmq einops ninja tabulate && \
|
||||
pip install numpy==1.26.2 && \
|
||||
sudo mkdir /home/jenkins && \
|
||||
sudo mkdir /home/jenkins/workspace && \
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1"
|
||||
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.1.1"
|
||||
FROM $BASE_DOCKER
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
|
||||
101
Dockerfile.manylinux
Normal file
101
Dockerfile.manylinux
Normal file
@@ -0,0 +1,101 @@
|
||||
FROM ghcr.io/rocm/therock_build_manylinux_x86_64:latest
|
||||
ARG DEBIAN_FRONTEND=noninteractive
|
||||
ARG ROCMVERSION=7.2
|
||||
ARG compiler_version=""
|
||||
ARG compiler_commit=""
|
||||
ARG CK_SCCACHE=""
|
||||
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
|
||||
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
USER root
|
||||
|
||||
# Add rocm repository
|
||||
RUN dnf clean all && dnf update -y && dnf -v install wget gnupg2 curl -y
|
||||
|
||||
RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2.70200-1.el8.noarch.rpm && \
|
||||
dnf install ./amdgpu-install-7.2.70200-1.el8.noarch.rpm -y && \
|
||||
dnf update -y && \
|
||||
dnf install python3-setuptools python3-wheel -y && \
|
||||
dnf install rocm-dev -y
|
||||
|
||||
## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined
|
||||
ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache
|
||||
ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin
|
||||
ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION}
|
||||
ENV CK_SCCACHE=$CK_SCCACHE
|
||||
RUN if [ "$CK_SCCACHE" != "" ]; then \
|
||||
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
|
||||
curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \
|
||||
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \
|
||||
fi
|
||||
|
||||
# Install dependencies
|
||||
RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \
|
||||
cmake \
|
||||
clang-tools-extra \
|
||||
gcc-c++ \
|
||||
libstdc++ \
|
||||
libstdc++-devel \
|
||||
libstdc++-static \
|
||||
git \
|
||||
hip-rocclr \
|
||||
jq \
|
||||
mpich \
|
||||
net-tools \
|
||||
pkg-config \
|
||||
redis \
|
||||
sshpass \
|
||||
stunnel \
|
||||
vim \
|
||||
nano \
|
||||
zip \
|
||||
openssh-server \
|
||||
kmod && \
|
||||
dnf clean all && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -rf amdgpu-install* && \
|
||||
#Install latest ccache
|
||||
git clone https://github.com/ccache/ccache.git && \
|
||||
cd ccache && mkdir build && cd build && cmake .. && make install && \
|
||||
#Install ClangBuildAnalyzer
|
||||
git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \
|
||||
cd ClangBuildAnalyzer/ && \
|
||||
make -f projects/make/Makefile && \
|
||||
cd / && \
|
||||
#Install latest cppcheck
|
||||
git clone https://github.com/danmar/cppcheck.git && \
|
||||
cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \
|
||||
cd / && \
|
||||
# Install packages for processing the performance results
|
||||
pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \
|
||||
# Add render group
|
||||
groupadd -f render && \
|
||||
# Install the new rocm-cmake version
|
||||
git clone -b master https://github.com/ROCm/rocm-cmake.git && \
|
||||
cd rocm-cmake && mkdir build && cd build && \
|
||||
cmake .. && cmake --build . && cmake --build . --target install
|
||||
|
||||
WORKDIR /
|
||||
# Add alternative compilers, if necessary
|
||||
ENV compiler_version=$compiler_version
|
||||
ENV compiler_commit=$compiler_commit
|
||||
RUN sh -c "echo compiler version = '$compiler_version'" && \
|
||||
sh -c "echo compiler commit = '$compiler_commit'"
|
||||
|
||||
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
|
||||
make -j 8 ; \
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
|
||||
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
|
||||
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
|
||||
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
|
||||
make -j 8 ; \
|
||||
else echo "using the release compiler"; \
|
||||
fi
|
||||
|
||||
@@ -20,4 +20,13 @@ RUN groupadd -g 109 render && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
chown -R jenkins:jenkins /tmp/pytorch && \
|
||||
chmod -R a+rwx /tmp/pytorch && \
|
||||
sudo usermod -aG irc jenkins
|
||||
sudo usermod -aG irc jenkins && \
|
||||
#install hipblaslt
|
||||
git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
|
||||
cd rocm-libraries && \
|
||||
git checkout develop && \
|
||||
git sparse-checkout init --cone && \
|
||||
git sparse-checkout set projects/hipblaslt shared/origami && \
|
||||
cd projects/hipblaslt && \
|
||||
git show --oneline -s && \
|
||||
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller
|
||||
|
||||
645
Jenkinsfile
vendored
645
Jenkinsfile
vendored
File diff suppressed because it is too large
Load Diff
18
README.md
18
README.md
@@ -137,6 +137,22 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
|
||||
```
|
||||
**[See Note on -j](#notes)**
|
||||
|
||||
### Building for Windows
|
||||
|
||||
Install TheRock and run CMake configure as
|
||||
|
||||
```bash
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH="C:/dist/TheRock" \
|
||||
-D CMAKE_CXX_COMPILER="C:/dist/TheRock/bin/hipcc.exe" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx1151" \
|
||||
-G Ninja \
|
||||
..
|
||||
```
|
||||
|
||||
Use Ninja to build either the whole library or individual targets.
|
||||
|
||||
## Optional post-install steps
|
||||
|
||||
* Build examples and tests:
|
||||
@@ -187,7 +203,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 Gb
|
||||
|
||||
Additional cmake flags can be used to significantly speed-up the build:
|
||||
|
||||
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
|
||||
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build
|
||||
instances of select data types only. The main default data types are fp32 and fp16; you can safely skip
|
||||
other data types.
|
||||
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_gemm gemm.cpp)
|
||||
target_link_libraries(client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_custom_target(client_gemm_fastgelu_examples)
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
|
||||
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
|
||||
target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp)
|
||||
target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations)
|
||||
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_softmax4d softmax4d.cpp)
|
||||
target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
|
||||
target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_fused_attention fused_attention.cpp)
|
||||
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES))
|
||||
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp)
|
||||
target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_grouped_conv1d_bwd_weight_fp16 grouped_conv1d_bwd_weight_fp16.cpp)
|
||||
add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp)
|
||||
add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp)
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_elementwise_layernorm2d elementwise_layernorm2d.cpp)
|
||||
target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_other_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp)
|
||||
add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp)
|
||||
add_executable(client_batchnorm_infer_nhwc batchnorm_infer_nhwc.cpp)
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp)
|
||||
target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_other_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp)
|
||||
add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp)
|
||||
target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations)
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp)
|
||||
target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_other_operations)
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
|
||||
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
|
||||
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
|
||||
target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_elementwise_transpose3d elementwise_transpose_3d.cpp)
|
||||
target_link_libraries(client_elementwise_transpose3d PRIVATE composable_kernel::device_other_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
# Fwd scaleadd scaleadd relu
|
||||
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
|
||||
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
|
||||
add_executable(client_wrapper_img2col wrapper_img2col.cpp)
|
||||
|
||||
@@ -1,2 +1,5 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_reduce_nhwc_c reduce_nhwc_c.cpp)
|
||||
target_link_libraries(client_reduce_nhwc_c PRIVATE composable_kernel::device_reduction_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(client_image_to_column image_to_column.cpp)
|
||||
target_link_libraries(client_image_to_column PRIVATE composable_kernel::device_other_operations)
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx950")
|
||||
add_executable(client_gemm_mx_fp8 gemm_mx_fp8.cpp)
|
||||
target_link_libraries(client_gemm_mx_fp8 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_gemm_add_multiply gemm_add_multiply.cpp)
|
||||
target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
|
||||
add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp)
|
||||
target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
|
||||
add_executable(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp)
|
||||
target_link_libraries(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
cmake_minimum_required(VERSION 3.15)
|
||||
project(ck_app)
|
||||
add_compile_options(-std=c++20)
|
||||
@@ -24,6 +27,9 @@ if (DTYPES)
|
||||
add_definitions(-DCK_ENABLE_FP32)
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "tf32")
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp64")
|
||||
add_definitions(-DCK_ENABLE_FP64)
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
@@ -38,6 +44,7 @@ else()
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
if (GPU_TARGETS MATCHES "gfx94")
|
||||
@@ -64,6 +71,14 @@ if (GPU_TARGETS)
|
||||
add_definitions(-DCK_USE_FNUZ_FP8)
|
||||
set(CK_USE_FNUZ_FP8 "ON")
|
||||
endif()
|
||||
if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
|
||||
add_definitions(-DCK_ENABLE_TF32)
|
||||
set(CK_ENABLE_TF32 "ON")
|
||||
else()
|
||||
message(STATUS "Disabling TF32 instances for this target")
|
||||
remove_definitions(-DCK_ENABLE_TF32)
|
||||
set(CK_ENABLE_TF32 "OFF")
|
||||
endif()
|
||||
else()
|
||||
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
|
||||
@@ -35,7 +35,7 @@ function(generate_sharded_instantiations)
|
||||
set(GENERATED_SOURCE_FILES "")
|
||||
set(EXTERN_TEMPLATE_STATEMENTS "")
|
||||
set(CALL_STATEMENTS "")
|
||||
message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}")
|
||||
message(DEBUG "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}")
|
||||
|
||||
set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}")
|
||||
|
||||
|
||||
@@ -68,6 +68,9 @@ set(GTEST_CXX_FLAGS
|
||||
-Wno-deprecated
|
||||
-Wno-unsafe-buffer-usage
|
||||
-Wno-float-equal
|
||||
-Wno-lifetime-safety-intra-tu-suggestions
|
||||
-Wno-lifetime-safety-cross-tu-suggestions
|
||||
-Wno-character-conversion
|
||||
)
|
||||
|
||||
if(WIN32)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
project(composable_kernel_host)
|
||||
|
||||
@@ -12,6 +15,7 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h)
|
||||
find_package(ROCM)
|
||||
include(ROCMInstallTargets)
|
||||
include(ROCMTest)
|
||||
list(APPEND CMAKE_PREFIX_PATH /opt/rocm $ENV{ROCM_PATH})
|
||||
find_package(hiprtc REQUIRED)
|
||||
|
||||
rocm_setup_version(VERSION 1.0)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
|
||||
add_subdirectory(rtc)
|
||||
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
find_package(hip)
|
||||
file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
add_library(ck_rtc ${RTC_SOURCES})
|
||||
|
||||
117
dispatcher/CMakeLists.txt
Normal file
117
dispatcher/CMakeLists.txt
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
|
||||
project(ck_tile_dispatcher VERSION 1.0.0 LANGUAGES CXX)
|
||||
|
||||
# C++17 required
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
# Find HIP for headers (needed for validation kernels)
|
||||
find_package(hip QUIET)
|
||||
if(NOT hip_FOUND)
|
||||
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip)
|
||||
find_package(hip REQUIRED)
|
||||
endif()
|
||||
|
||||
# Dispatcher library
|
||||
add_library(ck_tile_dispatcher
|
||||
src/registry.cpp
|
||||
src/dispatcher.cpp
|
||||
)
|
||||
|
||||
# Enable PIC for Python bindings
|
||||
set_target_properties(ck_tile_dispatcher PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
)
|
||||
|
||||
target_include_directories(ck_tile_dispatcher
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
|
||||
# Link against CK Tile headers (header-only)
|
||||
target_include_directories(ck_tile_dispatcher
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
|
||||
# Link against HIP headers if available
|
||||
if(hip_FOUND)
|
||||
target_link_libraries(ck_tile_dispatcher PUBLIC hip::host)
|
||||
endif()
|
||||
|
||||
# Compiler warnings
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
|
||||
target_compile_options(ck_tile_dispatcher PRIVATE
|
||||
-Wall -Wextra -Wpedantic
|
||||
)
|
||||
elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
|
||||
target_compile_options(ck_tile_dispatcher PRIVATE
|
||||
/W4
|
||||
)
|
||||
endif()
|
||||
|
||||
# Optional: Build tests
|
||||
option(BUILD_DISPATCHER_TESTS "Build dispatcher unit tests" OFF)
|
||||
if(BUILD_DISPATCHER_TESTS)
|
||||
enable_testing()
|
||||
add_subdirectory(tests)
|
||||
endif()
|
||||
|
||||
# Optional: Build Python bindings
|
||||
option(BUILD_DISPATCHER_PYTHON "Build Python bindings for dispatcher" OFF)
|
||||
if(BUILD_DISPATCHER_PYTHON)
|
||||
add_subdirectory(python)
|
||||
endif()
|
||||
|
||||
# Optional: Codegen for tile_engine integration
|
||||
option(DISPATCHER_AUTO_GENERATE_WRAPPERS "Auto-generate wrappers from tile_engine" OFF)
|
||||
if(DISPATCHER_AUTO_GENERATE_WRAPPERS)
|
||||
add_subdirectory(codegen)
|
||||
endif()
|
||||
|
||||
# Optional: Build examples
|
||||
option(BUILD_DISPATCHER_EXAMPLES "Build dispatcher examples" OFF)
|
||||
if(BUILD_DISPATCHER_EXAMPLES)
|
||||
add_subdirectory(examples)
|
||||
endif()
|
||||
|
||||
# Optional: Build ctypes bindings
|
||||
option(BUILD_DISPATCHER_BINDINGS "Build language bindings for dispatcher" OFF)
|
||||
if(BUILD_DISPATCHER_BINDINGS)
|
||||
add_subdirectory(bindings/ctypes)
|
||||
endif()
|
||||
|
||||
# If codegen is enabled, add generated include directory
|
||||
if(DISPATCHER_AUTO_GENERATE_WRAPPERS AND DISPATCHER_GENERATED_INCLUDE_DIR)
|
||||
target_include_directories(ck_tile_dispatcher
|
||||
PUBLIC
|
||||
$<BUILD_INTERFACE:${DISPATCHER_GENERATED_INCLUDE_DIR}>
|
||||
)
|
||||
endif()
|
||||
|
||||
# Installation
|
||||
install(TARGETS ck_tile_dispatcher
|
||||
EXPORT ck_tile_dispatcher_targets
|
||||
LIBRARY DESTINATION lib
|
||||
ARCHIVE DESTINATION lib
|
||||
RUNTIME DESTINATION bin
|
||||
)
|
||||
|
||||
install(DIRECTORY include/
|
||||
DESTINATION include
|
||||
FILES_MATCHING PATTERN "*.hpp"
|
||||
)
|
||||
|
||||
install(EXPORT ck_tile_dispatcher_targets
|
||||
FILE ck_tile_dispatcher_targets.cmake
|
||||
NAMESPACE ck_tile::
|
||||
DESTINATION lib/cmake/ck_tile_dispatcher
|
||||
)
|
||||
|
||||
736
dispatcher/README.md
Normal file
736
dispatcher/README.md
Normal file
@@ -0,0 +1,736 @@
|
||||
# CK Tile Dispatcher
|
||||
|
||||
A unified kernel dispatch system for AMD GPUs with C++ and Python frontends.
|
||||
|
||||
**Validated Platform:** AMD Instinct MI300 series (gfx942)
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Quick Start](#quick-start)
|
||||
2. [Docker Setup](#docker-setup-recommended)
|
||||
3. [Prerequisites](#prerequisites)
|
||||
4. [Step-by-Step Build Guide](#step-by-step-build-guide)
|
||||
5. [Running Examples](#running-examples)
|
||||
6. [External Integration](#external-integration)
|
||||
7. [Core Concepts](#core-concepts)
|
||||
8. [Troubleshooting](#troubleshooting)
|
||||
9. [File Structure](#file-structure)
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
**Complete setup from scratch (5 minutes):**
|
||||
|
||||
```bash
|
||||
# From the composable_kernel root directory
|
||||
cd dispatcher
|
||||
|
||||
# Step 1: Create build directory
|
||||
mkdir -p build && cd build
|
||||
|
||||
# Step 2: Configure CMake
|
||||
cmake .. \
|
||||
-DCMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGPU_TARGETS="gfx942" \
|
||||
-DBUILD_DISPATCHER_EXAMPLES=ON
|
||||
|
||||
# Step 3: Generate kernels and build (CMake handles this automatically)
|
||||
make -j$(nproc)
|
||||
|
||||
# Step 4: Run C++ examples
|
||||
./examples/gemm_01_basic
|
||||
|
||||
# Step 5: Build Python libraries (required for Python examples)
|
||||
make python_libs
|
||||
|
||||
# Step 6: Run Python examples (from dispatcher directory)
|
||||
cd ..
|
||||
python3 examples/gemm/python/01_basic_gemm.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Docker Setup (Recommended)
|
||||
|
||||
For a reproducible build environment, use the official ROCm Docker image:
|
||||
|
||||
### Step 1: Pull and Run Container
|
||||
|
||||
```bash
|
||||
# Pull the CK Docker image
|
||||
docker pull rocm/composable_kernel:ck_ub24.04_rocm7.0.1
|
||||
|
||||
# Run container with GPU access
|
||||
docker run \
|
||||
-it \
|
||||
--privileged \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
--group-add video \
|
||||
--group-add render \
|
||||
-w /root/workspace \
|
||||
-v $(pwd):/root/workspace \
|
||||
rocm/composable_kernel:ck_ub24.04_rocm7.0.1 \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
> **Note:** Omit `--device` flags if building without GPU access.
|
||||
|
||||
### Step 2: Clone and Build
|
||||
|
||||
```bash
|
||||
# Inside the container
|
||||
git clone https://github.com/ROCm/composable_kernel.git
|
||||
cd composable_kernel
|
||||
git checkout builder-dispatch-tile-gemm
|
||||
|
||||
# Set up Python environment
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install numpy
|
||||
|
||||
# Build dispatcher
|
||||
cd dispatcher
|
||||
mkdir -p build && cd build
|
||||
cmake .. \
|
||||
-DCMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGPU_TARGETS="gfx942" \
|
||||
-DBUILD_DISPATCHER_EXAMPLES=ON
|
||||
|
||||
make -j$(nproc)
|
||||
```
|
||||
|
||||
### One-Liner Build (inside container)
|
||||
|
||||
```bash
|
||||
git clone https://github.com/ROCm/composable_kernel.git && \
|
||||
cd composable_kernel && git checkout builder-dispatch-tile-gemm && \
|
||||
python3 -m venv .venv && source .venv/bin/activate && pip install numpy && \
|
||||
cd dispatcher && mkdir -p build && cd build && \
|
||||
cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-DCMAKE_BUILD_TYPE=Release -DGPU_TARGETS="gfx942" -DBUILD_DISPATCHER_EXAMPLES=ON && \
|
||||
make -j$(nproc)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Required Software
|
||||
|
||||
| Software | Minimum Version | Check Command |
|
||||
|----------|-----------------|---------------|
|
||||
| ROCm | 6.4+ | `rocminfo` |
|
||||
| CMake | 3.16+ | `cmake --version` |
|
||||
| Python | 3.8+ | `python3 --version` |
|
||||
| NumPy | 1.20+ | `pip show numpy` |
|
||||
| hipcc | (from ROCm) | `/opt/rocm/bin/hipcc --version` |
|
||||
|
||||
> **Note:** Newer GPU targets (gfx950, gfx1201) require ROCm 6.3+. For ROCm 6.4+, you can also use `amdclang++` instead of `hipcc`.
|
||||
|
||||
### Check Your GPU Architecture
|
||||
|
||||
```bash
|
||||
# Find your GPU architecture
|
||||
rocminfo | grep -i "gfx"
|
||||
# Example output: "gfx942"
|
||||
```
|
||||
|
||||
**Supported architectures:**
|
||||
- **gfx942** - MI300X, MI300A, MI308, MI325 (Instinct MI300 series)
|
||||
- **gfx90a** - MI200 series (MI250, MI250X)
|
||||
- **gfx950** - MI350 series
|
||||
- **gfx1101** - RDNA3 series
|
||||
- **gfx1201** - RDNA4 series
|
||||
|
||||
### Install Python Dependencies
|
||||
|
||||
NumPy is required for Python examples and kernel generation. We recommend using a virtual environment:
|
||||
|
||||
**Option 1: Using standard venv**
|
||||
```bash
|
||||
# Create virtual environment
|
||||
python3 -m venv .venv
|
||||
|
||||
# Activate virtual environment
|
||||
source .venv/bin/activate # Linux/macOS
|
||||
# .venv\Scripts\activate # Windows
|
||||
|
||||
# Install NumPy
|
||||
pip install numpy
|
||||
```
|
||||
|
||||
**Option 2: Using uv (faster alternative)**
|
||||
```bash
|
||||
# Install uv if not already installed
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Create and activate virtual environment
|
||||
uv venv .venv
|
||||
source .venv/bin/activate # Linux/macOS
|
||||
# .venv\Scripts\activate # Windows
|
||||
|
||||
# Install NumPy
|
||||
uv pip install numpy
|
||||
```
|
||||
|
||||
**Option 3: System-wide install (not recommended)**
|
||||
```bash
|
||||
pip install numpy
|
||||
```
|
||||
|
||||
> **Note:** Always activate your virtual environment before running CMake or Python examples.
|
||||
|
||||
### Supported Data Types
|
||||
|
||||
CK Tile supports a wide range of data types for GEMM operations:
|
||||
|
||||
| A dtype | B dtype | Acc dtype | Warp Tile Sizes | Notes |
|
||||
|---------|---------|-----------|-----------------|-------|
|
||||
| `fp32` | `fp32` | `fp32` | 16x16x4, 16x16x16 | Full precision |
|
||||
| `fp16` | `fp16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Standard half |
|
||||
| `bf16` | `bf16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Brain float 16 |
|
||||
| `fp8` | `fp8` | `fp32` | 32x32x16, 32x32x32, 16x16x32, 16x16x64 | FP8 E4M3 |
|
||||
| `fp8` | `bf8` | `fp32` | 32x32x16, 16x16x32 | Mixed FP8/BF8 |
|
||||
| `bf8` | `fp8` | `fp32` | 32x32x16, 16x16x128 | Mixed BF8/FP8 |
|
||||
| `bf8` | `bf8` | `fp32` | 32x32x16, 32x32x32, 16x16x32 | BF8 E5M2 |
|
||||
| `int8` | `int8` | `int32` | 32x32x16, 16x16x32, 16x16x16 | Integer GEMM |
|
||||
| `pk_fp4` | `pk_fp4` | `fp32` | 16x16x128 | Packed 4-bit float |
|
||||
|
||||
**Notes:**
|
||||
- Accumulator is always `fp32` except for `int8` which uses `int32`
|
||||
- FP8 types: `fp8` = E4M3, `bf8` = E5M2
|
||||
- `pk_fp4` = Packed 4-bit float (2 values per byte)
|
||||
- Some dtypes require specific GPU architectures (e.g., FP8 requires MI300+)
|
||||
|
||||
---
|
||||
|
||||
## Step-by-Step Build Guide
|
||||
|
||||
### Step 1: Navigate to Dispatcher Directory
|
||||
|
||||
```bash
|
||||
# From composable_kernel root
|
||||
cd dispatcher
|
||||
|
||||
# Verify you're in the right place
|
||||
ls CMakeLists.txt # Should exist
|
||||
```
|
||||
|
||||
### Step 2: Create Build Directory
|
||||
|
||||
```bash
|
||||
mkdir -p build
|
||||
cd build
|
||||
```
|
||||
|
||||
### Step 3: Configure CMake
|
||||
|
||||
**Basic configuration (library only):**
|
||||
```bash
|
||||
cmake .. \
|
||||
-DCMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGPU_TARGETS="gfx942"
|
||||
```
|
||||
|
||||
**Full configuration (with examples and tests):**
|
||||
```bash
|
||||
cmake .. \
|
||||
-DCMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGPU_TARGETS="gfx942" \
|
||||
-DBUILD_DISPATCHER_EXAMPLES=ON \
|
||||
-DBUILD_DISPATCHER_TESTS=ON
|
||||
```
|
||||
|
||||
**Expected output:**
|
||||
```
|
||||
-- Found hip: /opt/rocm (found suitable version "6.x.x")
|
||||
-- Generating GEMM kernels...
|
||||
-- Built: gemm_01 through gemm_06, dispatcher_gemm_lib.so
|
||||
-- Configuring done
|
||||
```
|
||||
|
||||
### Step 4: Build
|
||||
|
||||
```bash
|
||||
# Build all targets (generates kernels automatically, then compiles)
|
||||
make -j$(nproc)
|
||||
|
||||
# Or build specific targets
|
||||
make gemm_01_basic # Single GEMM example
|
||||
make dispatcher_gemm_lib # GEMM shared library for Python
|
||||
|
||||
# Build ONLY Python libraries (faster if you don't need C++ examples)
|
||||
make python_libs -j$(nproc)
|
||||
```
|
||||
|
||||
### Kernel Generation Targets
|
||||
|
||||
Kernels are generated automatically during `make`, but you can also control generation explicitly:
|
||||
|
||||
```bash
|
||||
# Generate all kernels only (no compilation)
|
||||
make generate_all_kernels
|
||||
|
||||
# Generate GEMM kernels only
|
||||
make generate_gemm_kernels
|
||||
|
||||
# Force regenerate (even if kernels exist)
|
||||
make regenerate_all_kernels
|
||||
make regenerate_gemm_kernels
|
||||
|
||||
# Generate for specific GPU architecture
|
||||
make generate_kernels_gfx942 # MI300X
|
||||
make generate_kernels_gfx90a # MI200
|
||||
make generate_kernels_gfx1100 # RDNA3
|
||||
```
|
||||
|
||||
### Step 5: Verify Build
|
||||
|
||||
```bash
|
||||
# Check executables were built
|
||||
ls examples/gemm_*
|
||||
|
||||
# Check shared libraries were built
|
||||
ls examples/libdispatcher_gemm_lib.so
|
||||
```
|
||||
|
||||
### CMake Options Reference
|
||||
|
||||
| Flag | Default | Description |
|
||||
|------|---------|-------------|
|
||||
| `CMAKE_BUILD_TYPE` | Debug | **Use `Release` for performance!** |
|
||||
| `GPU_TARGETS` | None | Target GPU: `"gfx942"`, `"gfx90a"`, etc. |
|
||||
| `BUILD_DISPATCHER_EXAMPLES` | OFF | Build C++ examples and Python libs |
|
||||
| `BUILD_DISPATCHER_TESTS` | OFF | Build unit tests |
|
||||
| `CMAKE_PREFIX_PATH` | - | ROCm installation path |
|
||||
| `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler |
|
||||
|
||||
⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
|
||||
⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories).
|
||||
|
||||
---
|
||||
|
||||
## Running Examples
|
||||
|
||||
### C++ Examples
|
||||
|
||||
After building, executables are in `build/examples/`:
|
||||
|
||||
```bash
|
||||
cd build/examples
|
||||
|
||||
# GEMM Examples
|
||||
./gemm_01_basic # Basic GEMM with autofill/autocorrect
|
||||
./gemm_02_multi_size # Wildcard expansion
|
||||
./gemm_03_benchmark_validation # Benchmarking + validation
|
||||
./gemm_04_heuristics # Heuristic kernel selection
|
||||
./gemm_05_json_export # Registry JSON export
|
||||
./gemm_06_multi_registry # Multiple registries
|
||||
```
|
||||
|
||||
### Python Examples
|
||||
|
||||
Run from the `dispatcher` directory:
|
||||
|
||||
```bash
|
||||
cd /path/to/composable_kernel/dispatcher
|
||||
|
||||
# GEMM Examples
|
||||
python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM
|
||||
python3 examples/gemm/python/04_validation.py # CPU reference validation
|
||||
python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels)
|
||||
python3 examples/gemm/python/08_heuristics.py # Heuristic selection
|
||||
```
|
||||
|
||||
### Example Output
|
||||
|
||||
**Expected C++ output (`gemm_01_basic`):**
|
||||
```
|
||||
======================================================================
|
||||
Example 01: Basic GEMM with Declarative Kernel Definition
|
||||
======================================================================
|
||||
|
||||
Step 1: Declared Kernels
|
||||
------------------------
|
||||
Kernel Set: fp16_gemm_kernels
|
||||
Architecture: gfx942
|
||||
Configurations: 1
|
||||
- gemm_fp16_rcr_compv4_cshuffle_intrawave_128x128x32
|
||||
|
||||
Step 2: Create Registry and Dispatcher
|
||||
--------------------------------------
|
||||
Registered 1 kernels
|
||||
|
||||
Step 3: Define Problem
|
||||
----------------------
|
||||
M=1024, N=1024, K=1024
|
||||
|
||||
Step 4: GPU Execution
|
||||
---------------------
|
||||
*** GPU EXECUTION ***
|
||||
Time: <varies> ms
|
||||
TFLOPS: <varies>
|
||||
```
|
||||
|
||||
> **Note:** Timing values vary by GPU model and system configuration.
|
||||
|
||||
---
|
||||
|
||||
## Benchmark Parameters
|
||||
|
||||
The dispatcher supports fine-grained control over benchmarking, matching CK Tile's `stream_config`:
|
||||
|
||||
### Available Parameters
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-----------|------|---------|-------------|
|
||||
| `warmup` | int | 5 | Warmup iterations (discarded from timing) |
|
||||
| `repeat` | int | 20 | Benchmark iterations (averaged) |
|
||||
| `flush_cache` | bool | false | Flush GPU L2 cache between iterations |
|
||||
| `rotating_count` | int | 1 | Rotating buffer count (for cache simulation) |
|
||||
| `timer` | string | "gpu" | Timer type: "gpu" (HIP events) or "cpu" |
|
||||
| `init` | string | "random" | Matrix initialization: "random", "linear", "constant" |
|
||||
| `split_k` | int | 1 | Split-K parallelism factor |
|
||||
|
||||
### Python Usage
|
||||
|
||||
```python
|
||||
from ctypes_utils import DispatcherLib
|
||||
|
||||
# Basic usage (default benchmark settings)
|
||||
lib = DispatcherLib.load()
|
||||
|
||||
# Advanced benchmark settings via command line
|
||||
python3 examples/gemm/python/10_advanced_benchmark.py \
|
||||
--warmup 10 \
|
||||
--repeat 100 \
|
||||
--flush-cache
|
||||
```
|
||||
|
||||
### C++ Usage
|
||||
|
||||
```cpp
|
||||
// Basic timing
|
||||
ck_tile::stream_config cfg{nullptr, true};
|
||||
|
||||
// Advanced benchmark settings
|
||||
ck_tile::stream_config cfg{
|
||||
nullptr, // stream_id (nullptr = default stream)
|
||||
true, // time_kernel
|
||||
1, // log_level
|
||||
10, // cold_niters (warmup)
|
||||
100, // nrepeat
|
||||
true, // is_gpu_timer
|
||||
true, // flush_cache
|
||||
4 // rotating_count
|
||||
};
|
||||
|
||||
float avg_time = kernel.run(args, cfg);
|
||||
```
|
||||
|
||||
### Command Line (Python Examples)
|
||||
|
||||
```bash
|
||||
# Basic run
|
||||
python3 examples/gemm/python/10_advanced_benchmark.py
|
||||
|
||||
# With benchmark parameters
|
||||
python3 examples/gemm/python/10_advanced_benchmark.py \
|
||||
--warmup 10 \
|
||||
--repeat 100 \
|
||||
--flush-cache \
|
||||
--rotating-count 4 \
|
||||
--timer gpu
|
||||
```
|
||||
|
||||
### When to Use Each Parameter
|
||||
|
||||
| Use Case | Recommended Settings |
|
||||
|----------|---------------------|
|
||||
| Quick test | `warmup=1, repeat=3` |
|
||||
| Stable benchmark | `warmup=10, repeat=100` |
|
||||
| Memory-bound analysis | `flush_cache=True, rotating_count=4` |
|
||||
| Compute-bound analysis | `flush_cache=False` (default) |
|
||||
| Debug timing | `timer="cpu"` |
|
||||
| Production | `timer="gpu"` (default) |
|
||||
|
||||
---
|
||||
|
||||
## External Integration
|
||||
|
||||
### Using Dispatcher in Your Own Project
|
||||
|
||||
#### Option 1: CMake Integration (Recommended)
|
||||
|
||||
Add to your `CMakeLists.txt`:
|
||||
|
||||
```cmake
|
||||
# Set path to composable_kernel
|
||||
set(CK_ROOT "/path/to/composable_kernel")
|
||||
|
||||
# Add dispatcher subdirectory
|
||||
add_subdirectory(${CK_ROOT}/dispatcher dispatcher_build)
|
||||
|
||||
# Link to your target
|
||||
target_link_libraries(your_target PRIVATE ck_tile_dispatcher)
|
||||
target_include_directories(your_target PRIVATE
|
||||
${CK_ROOT}/dispatcher/include
|
||||
${CK_ROOT}/include
|
||||
)
|
||||
```
|
||||
|
||||
#### Option 2: Include as Pre-built Library
|
||||
|
||||
```cmake
|
||||
# Find the pre-built library
|
||||
find_library(CK_DISPATCHER ck_tile_dispatcher
|
||||
PATHS /path/to/composable_kernel/dispatcher/build)
|
||||
|
||||
# Include directories
|
||||
set(CK_INCLUDE_DIRS
|
||||
/path/to/composable_kernel/include
|
||||
/path/to/composable_kernel/dispatcher/include
|
||||
)
|
||||
|
||||
target_link_libraries(your_target PRIVATE ${CK_DISPATCHER})
|
||||
target_include_directories(your_target PRIVATE ${CK_INCLUDE_DIRS})
|
||||
```
|
||||
|
||||
#### Option 3: Python Integration
|
||||
|
||||
```python
|
||||
import sys
|
||||
sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python")
|
||||
|
||||
# For GEMM
|
||||
from ctypes_utils import DispatcherLib, Dispatcher, KernelConfig
|
||||
```
|
||||
|
||||
### Required Include Paths
|
||||
|
||||
When integrating, you need these include paths:
|
||||
|
||||
```
|
||||
/path/to/composable_kernel/include # CK Tile core headers
|
||||
/path/to/composable_kernel/dispatcher/include # Dispatcher headers
|
||||
/path/to/composable_kernel/dispatcher/build/generated_kernels # Generated kernels
|
||||
```
|
||||
|
||||
### Required Compile Flags
|
||||
|
||||
```bash
|
||||
# Minimum flags for hipcc
|
||||
-std=c++17
|
||||
-D__HIP_PLATFORM_AMD__=1
|
||||
--offload-arch=gfx942 # Your target GPU
|
||||
|
||||
# Recommended flags
|
||||
-O3
|
||||
-mllvm -enable-noalias-to-md-conversion=0
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
-Wall
|
||||
-Werror
|
||||
```
|
||||
|
||||
### Python Path Setup
|
||||
|
||||
For Python scripts outside the dispatcher directory:
|
||||
|
||||
```bash
|
||||
# Option 1: Environment variable
|
||||
export PYTHONPATH="/path/to/composable_kernel/dispatcher/examples/gemm/python:$PYTHONPATH"
|
||||
|
||||
# Option 2: In your Python script
|
||||
import sys
|
||||
sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python")
|
||||
```
|
||||
|
||||
### Library Search Paths
|
||||
|
||||
The Python utilities search for the shared library in these locations:
|
||||
|
||||
```python
|
||||
# For GEMM (ctypes_utils.py)
|
||||
SEARCH_PATHS = [
|
||||
"build/examples/libdispatcher_gemm_lib.so",
|
||||
"../build/examples/libdispatcher_gemm_lib.so",
|
||||
"../../build/examples/libdispatcher_gemm_lib.so",
|
||||
]
|
||||
```
|
||||
|
||||
If using from a different location, set the library path explicitly:
|
||||
|
||||
```python
|
||||
# GEMM
|
||||
from ctypes_utils import DispatcherLib
|
||||
lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### Data Flow
|
||||
|
||||
```
|
||||
KernelConfig → Registry → Dispatcher → GPU Execution
|
||||
```
|
||||
|
||||
1. **KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts)
|
||||
2. **Registry**: Stores multiple kernel configurations
|
||||
3. **Dispatcher**: Selects best kernel for a given problem and executes it
|
||||
|
||||
### GEMM Layouts
|
||||
|
||||
| Layout | A | B | C | Use Case |
|
||||
|--------|---|---|---|----------|
|
||||
| RCR | Row | Col | Row | Most common (PyTorch default) |
|
||||
| RRR | Row | Row | Row | Both inputs row-major |
|
||||
| CRR | Col | Row | Row | A transposed |
|
||||
| CCR | Col | Col | Row | Both inputs column-major |
|
||||
|
||||
### Split-K Support
|
||||
|
||||
Split-K divides the K dimension across multiple thread blocks, useful for large K dimensions.
|
||||
|
||||
**Usage (C++):**
|
||||
```cpp
|
||||
// GEMM with 4-way K split
|
||||
auto problem = ProblemBuilder()
|
||||
.m(1024).n(1024).k(8192)
|
||||
.split_k(4)
|
||||
.build();
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Build Issues
|
||||
|
||||
| Problem | Solution |
|
||||
|---------|----------|
|
||||
| `hipcc not found` | Set `-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc` |
|
||||
| `hip not found` | Set `-DCMAKE_PREFIX_PATH=/opt/rocm` |
|
||||
| Very slow performance | Use `-DCMAKE_BUILD_TYPE=Release` |
|
||||
| `gfx942 not supported` | Check ROCm version (need 6.0+) |
|
||||
| Kernel generation fails | Ensure Python 3.8+ with NumPy installed in active venv |
|
||||
| Build errors | First verify CK builds without dispatcher (see main CK README) |
|
||||
|
||||
### Runtime Issues
|
||||
|
||||
| Problem | Solution |
|
||||
|---------|----------|
|
||||
| `Library not found` | Build with `-DBUILD_DISPATCHER_EXAMPLES=ON` |
|
||||
| `No kernel found` | Check GPU arch matches build target |
|
||||
| Python `ModuleNotFoundError` | Add paths to `PYTHONPATH` (see above) |
|
||||
| Wrong results | Verify layout matches your data |
|
||||
|
||||
### Debug Commands
|
||||
|
||||
```bash
|
||||
# Check ROCm installation
|
||||
rocminfo | head -20
|
||||
|
||||
# Check GPU architecture
|
||||
rocminfo | grep "Name:"
|
||||
|
||||
# Verify library exists
|
||||
ls -la build/examples/libdispatcher_*.so
|
||||
|
||||
# Run with verbose output
|
||||
./build/examples/gemm_01_basic 2>&1
|
||||
|
||||
# Python: Check library loading
|
||||
python3 -c "
|
||||
import ctypes
|
||||
lib = ctypes.CDLL('/path/to/libdispatcher_gemm_lib.so')
|
||||
print('Library loaded successfully')
|
||||
"
|
||||
```
|
||||
|
||||
### Clean Rebuild
|
||||
|
||||
If you encounter issues, try a clean rebuild:
|
||||
|
||||
```bash
|
||||
cd dispatcher
|
||||
rm -rf build
|
||||
mkdir build && cd build
|
||||
cmake .. [your options]
|
||||
make -j$(nproc)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
dispatcher/
|
||||
├── README.md # This file
|
||||
├── CMakeLists.txt # Build configuration
|
||||
│
|
||||
├── include/ck_tile/dispatcher/ # C++ headers
|
||||
│ ├── dispatcher.hpp # GEMM dispatcher
|
||||
│ ├── registry.hpp # Kernel registry
|
||||
│ └── kernel_key.hpp # Kernel configuration
|
||||
│
|
||||
├── src/ # C++ implementation
|
||||
│
|
||||
├── codegen/ # Kernel generation
|
||||
│ ├── unified_gemm_codegen.py # GEMM kernel generator
|
||||
│ └── arch_specs.json # GPU specifications
|
||||
│
|
||||
├── bindings/ctypes/ # Python ctypes interface
|
||||
│ └── gemm_ctypes_lib.cpp # GEMM Python library
|
||||
│
|
||||
├── examples/ # Examples
|
||||
│ └── gemm/
|
||||
│ ├── cpp/ # C++ GEMM examples (01-06)
|
||||
│ └── python/ # Python GEMM examples (01-11)
|
||||
│
|
||||
├── scripts/ # Build scripts
|
||||
│
|
||||
└── tests/ # Unit tests
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Example Documentation
|
||||
|
||||
| Directory | README |
|
||||
|-----------|--------|
|
||||
| GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) |
|
||||
| GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) |
|
||||
| Codegen | [codegen/README.md](codegen/README.md) |
|
||||
|
||||
---
|
||||
|
||||
## Archived Content
|
||||
|
||||
Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`:
|
||||
- `examples/conv/cpp/` - 11 C++ convolution examples
|
||||
- `examples/conv/python/` - 14 Python convolution examples
|
||||
- `codegen/unified_conv_codegen.py` - Conv kernel generator
|
||||
- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers
|
||||
- `python/conv_utils.py` - Conv Python utilities
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT License - Copyright (c) 2025, Advanced Micro Devices, Inc.
|
||||
109
dispatcher/bindings/README.md
Normal file
109
dispatcher/bindings/README.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# CK Tile Dispatcher - Language Bindings
|
||||
|
||||
This directory contains language bindings for the CK Tile Dispatcher.
|
||||
|
||||
## Structure
|
||||
|
||||
```
|
||||
bindings/
|
||||
├── ctypes/ # Python ctypes bindings (C API)
|
||||
│ ├── gemm_ctypes_lib.cpp # GEMM dispatcher C API
|
||||
│ ├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data)
|
||||
│ ├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API
|
||||
│ ├── gpu_helper.cpp # CLI helper for Python
|
||||
│ └── CMakeLists.txt
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## ctypes Bindings
|
||||
|
||||
The ctypes bindings provide a C API that Python can load via `ctypes.CDLL()`.
|
||||
|
||||
### Building
|
||||
|
||||
```bash
|
||||
cd build
|
||||
cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm
|
||||
make dispatcher_gemm_lib dispatcher_conv_lib gpu_helper
|
||||
```
|
||||
|
||||
### Usage from Python
|
||||
|
||||
```python
|
||||
import ctypes
|
||||
|
||||
# Load the library
|
||||
lib = ctypes.CDLL("path/to/libdispatcher_gemm_lib.so")
|
||||
|
||||
# Initialize
|
||||
lib.dispatcher_init()
|
||||
|
||||
# Check if problem is supported
|
||||
is_supported = lib.dispatcher_is_supported(M, N, K)
|
||||
|
||||
# Run GEMM
|
||||
time_ms = ctypes.c_float()
|
||||
result = lib.dispatcher_run_gemm(
|
||||
A_ptr, B_ptr, C_ptr,
|
||||
M, N, K,
|
||||
ctypes.byref(time_ms)
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
lib.dispatcher_cleanup()
|
||||
```
|
||||
|
||||
### GEMM API
|
||||
|
||||
| Function | Description |
|
||||
|----------|-------------|
|
||||
| `dispatcher_init()` | Initialize the dispatcher |
|
||||
| `dispatcher_is_supported(M, N, K)` | Check if problem size is supported |
|
||||
| `dispatcher_select_kernel(M, N, K, name_buf, buf_size)` | Get kernel name for problem |
|
||||
| `dispatcher_run_gemm(A, B, C, M, N, K, time_ms)` | Execute GEMM |
|
||||
| `dispatcher_get_kernel_count()` | Get number of registered kernels |
|
||||
| `dispatcher_export_registry_json()` | Export registry as JSON |
|
||||
| `dispatcher_cleanup()` | Release resources |
|
||||
|
||||
### Convolution API
|
||||
|
||||
| Function | Description |
|
||||
|----------|-------------|
|
||||
| `conv_dispatcher_init()` | Initialize the dispatcher |
|
||||
| `conv_dispatcher_is_supported(prob)` | Check if problem is supported |
|
||||
| `conv_dispatcher_select_kernel(prob, name_buf, buf_size)` | Get kernel name |
|
||||
| `conv_dispatcher_run(input, weight, output, prob, stream)` | Execute convolution |
|
||||
| `conv_dispatcher_get_kernel_count()` | Get number of registered kernels |
|
||||
| `conv_dispatcher_cleanup()` | Release resources |
|
||||
|
||||
## GPU Helper
|
||||
|
||||
The `gpu_helper` executable provides a CLI interface for Python:
|
||||
|
||||
```bash
|
||||
./gpu_helper 1024 1024 1024 --validate
|
||||
```
|
||||
|
||||
Output is JSON for easy parsing:
|
||||
```json
|
||||
{
|
||||
"problem": {"M": 1024, "N": 1024, "K": 1024},
|
||||
"kernel": "gemm_fp16_rcr_...",
|
||||
"execution": {
|
||||
"time_ms": 0.5,
|
||||
"tflops": 4.2
|
||||
},
|
||||
"validation": {
|
||||
"accuracy": 100.0
|
||||
},
|
||||
"status": "success"
|
||||
}
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
See the examples that use these bindings:
|
||||
|
||||
- **GEMM**: `dispatcher/examples/gemm/python/`
|
||||
- **Conv**: `dispatcher/examples/conv/python/`
|
||||
|
||||
181
dispatcher/bindings/ctypes/CMakeLists.txt
Normal file
181
dispatcher/bindings/ctypes/CMakeLists.txt
Normal file
@@ -0,0 +1,181 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# =============================================================================
|
||||
# CK Tile Dispatcher - ctypes Bindings
|
||||
# =============================================================================
|
||||
#
|
||||
# Provides shared libraries with C API for Python ctypes integration.
|
||||
#
|
||||
# Targets:
|
||||
# - dispatcher_gemm_lib : GEMM dispatcher library
|
||||
# - dispatcher_conv_lib : Convolution dispatcher library (forward + bwd_data)
|
||||
# - dispatcher_conv_bwdw_lib : Convolution backward weight library
|
||||
# - gpu_helper : GPU helper executable for Python
|
||||
#
|
||||
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
|
||||
# Helper function to add a ctypes library
|
||||
function(add_ctypes_library TARGET_NAME SOURCE_FILE)
|
||||
cmake_parse_arguments(ARG "CONV" "KERNEL_HEADER" "" ${ARGN})
|
||||
|
||||
add_library(${TARGET_NAME} SHARED ${SOURCE_FILE})
|
||||
|
||||
target_include_directories(${TARGET_NAME} PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/dispatcher/include
|
||||
)
|
||||
|
||||
target_link_libraries(${TARGET_NAME} PRIVATE
|
||||
hip::device
|
||||
)
|
||||
|
||||
# Force-include kernel header if provided
|
||||
if(ARG_KERNEL_HEADER AND EXISTS ${ARG_KERNEL_HEADER})
|
||||
target_compile_options(${TARGET_NAME} PRIVATE
|
||||
-include ${ARG_KERNEL_HEADER}
|
||||
)
|
||||
if(ARG_CONV)
|
||||
target_compile_definitions(${TARGET_NAME} PRIVATE CONV_KERNEL_AVAILABLE)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set_target_properties(${TARGET_NAME} PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
CXX_STANDARD 17
|
||||
)
|
||||
endfunction()
|
||||
|
||||
# =============================================================================
|
||||
# GEMM ctypes Library
|
||||
# =============================================================================
|
||||
|
||||
# Find a generated GEMM kernel header for the library
|
||||
file(GLOB GEMM_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/gemm_*.hpp")
|
||||
if(GEMM_KERNEL_HEADERS)
|
||||
list(GET GEMM_KERNEL_HEADERS 0 GEMM_KERNEL_HEADER)
|
||||
message(STATUS "Found GEMM kernel for ctypes lib: ${GEMM_KERNEL_HEADER}")
|
||||
|
||||
add_ctypes_library(dispatcher_gemm_lib
|
||||
gemm_ctypes_lib.cpp
|
||||
KERNEL_HEADER ${GEMM_KERNEL_HEADER}
|
||||
)
|
||||
else()
|
||||
message(STATUS "No GEMM kernel found for ctypes lib - building without kernel")
|
||||
add_library(dispatcher_gemm_lib SHARED gemm_ctypes_lib.cpp)
|
||||
target_include_directories(dispatcher_gemm_lib PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/dispatcher/include
|
||||
)
|
||||
target_link_libraries(dispatcher_gemm_lib PRIVATE hip::device)
|
||||
endif()
|
||||
|
||||
# =============================================================================
|
||||
# Convolution ctypes Library (supports forward + bwd_data)
|
||||
# =============================================================================
|
||||
|
||||
# Look for forward kernels
|
||||
file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp")
|
||||
# Look for backward data kernels
|
||||
file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp")
|
||||
# Fallback: any conv kernel (for backwards compatibility)
|
||||
file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp")
|
||||
|
||||
add_library(dispatcher_conv_lib SHARED conv_ctypes_lib.cpp)
|
||||
target_include_directories(dispatcher_conv_lib PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/dispatcher/include
|
||||
)
|
||||
target_link_libraries(dispatcher_conv_lib PRIVATE hip::device)
|
||||
set_target_properties(dispatcher_conv_lib PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
CXX_STANDARD 17
|
||||
)
|
||||
|
||||
# Add forward kernel if available
|
||||
if(CONV_FWD_KERNEL_HEADERS)
|
||||
list(GET CONV_FWD_KERNEL_HEADERS 0 CONV_FWD_KERNEL_HEADER)
|
||||
message(STATUS "Found Conv FWD kernel for ctypes lib: ${CONV_FWD_KERNEL_HEADER}")
|
||||
target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_FWD_KERNEL_HEADER})
|
||||
target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE)
|
||||
elseif(CONV_KERNEL_HEADERS)
|
||||
# Fallback to any conv kernel
|
||||
list(GET CONV_KERNEL_HEADERS 0 CONV_KERNEL_HEADER)
|
||||
message(STATUS "Found Conv kernel for ctypes lib: ${CONV_KERNEL_HEADER}")
|
||||
target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_KERNEL_HEADER})
|
||||
target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE)
|
||||
else()
|
||||
message(STATUS "No Conv FWD kernel found for ctypes lib - building without kernel")
|
||||
endif()
|
||||
|
||||
# Add backward data kernel if available
|
||||
if(CONV_BWDD_KERNEL_HEADERS)
|
||||
list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER)
|
||||
message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}")
|
||||
target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER})
|
||||
target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE)
|
||||
endif()
|
||||
|
||||
# =============================================================================
|
||||
# Convolution Backward Weight ctypes Library (separate lib for bwd_weight)
|
||||
# =============================================================================
|
||||
|
||||
file(GLOB CONV_BWDW_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*bwd_weight*.hpp")
|
||||
if(CONV_BWDW_KERNEL_HEADERS)
|
||||
list(GET CONV_BWDW_KERNEL_HEADERS 0 CONV_BWDW_KERNEL_HEADER)
|
||||
message(STATUS "Found Conv BwdWeight kernel for ctypes lib: ${CONV_BWDW_KERNEL_HEADER}")
|
||||
|
||||
add_library(dispatcher_conv_bwdw_lib SHARED conv_bwdw_ctypes_lib.cpp)
|
||||
target_include_directories(dispatcher_conv_bwdw_lib PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/dispatcher/include
|
||||
)
|
||||
target_link_libraries(dispatcher_conv_bwdw_lib PRIVATE hip::device)
|
||||
target_compile_options(dispatcher_conv_bwdw_lib PRIVATE
|
||||
-include ${CONV_BWDW_KERNEL_HEADER}
|
||||
)
|
||||
target_compile_definitions(dispatcher_conv_bwdw_lib PRIVATE CONV_BWD_WEIGHT_AVAILABLE)
|
||||
set_target_properties(dispatcher_conv_bwdw_lib PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
CXX_STANDARD 17
|
||||
)
|
||||
else()
|
||||
message(STATUS "No Conv BwdWeight kernel found for ctypes lib - building without kernel")
|
||||
add_library(dispatcher_conv_bwdw_lib SHARED conv_bwdw_ctypes_lib.cpp)
|
||||
target_include_directories(dispatcher_conv_bwdw_lib PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/dispatcher/include
|
||||
)
|
||||
target_link_libraries(dispatcher_conv_bwdw_lib PRIVATE hip::device)
|
||||
set_target_properties(dispatcher_conv_bwdw_lib PROPERTIES
|
||||
POSITION_INDEPENDENT_CODE ON
|
||||
CXX_STANDARD 17
|
||||
)
|
||||
endif()
|
||||
|
||||
# =============================================================================
|
||||
# GPU Helper Executable
|
||||
# =============================================================================
|
||||
|
||||
if(GEMM_KERNEL_HEADERS)
|
||||
add_executable(gpu_helper gpu_helper.cpp)
|
||||
|
||||
target_include_directories(gpu_helper PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/dispatcher/include
|
||||
)
|
||||
|
||||
target_link_libraries(gpu_helper PRIVATE
|
||||
hip::device
|
||||
)
|
||||
|
||||
target_compile_options(gpu_helper PRIVATE
|
||||
-include ${GEMM_KERNEL_HEADER}
|
||||
)
|
||||
|
||||
set_target_properties(gpu_helper PROPERTIES
|
||||
CXX_STANDARD 17
|
||||
)
|
||||
endif()
|
||||
|
||||
175
dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp
Normal file
175
dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp
Normal file
@@ -0,0 +1,175 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* Convolution Backward Weight Dispatcher ctypes Library
|
||||
*
|
||||
* SEPARATE library for backward weight to avoid template conflicts with
|
||||
* forward/backward_data kernels in the main conv_ctypes_lib.
|
||||
*
|
||||
* Usage from Python:
|
||||
* lib = ctypes.CDLL("libdispatcher_conv_bwdw_lib.so")
|
||||
* lib.conv_bwdw_init()
|
||||
* lib.conv_bwdw_run(...)
|
||||
*/
|
||||
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
// Minimal includes - matching the C++ example
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp" // Must be before grouped_convolution for TileGemmTraits
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
// Global state - minimal, no registry needed for direct launch
|
||||
static bool g_bwdw_initialized = false;
|
||||
|
||||
extern "C" {
|
||||
|
||||
// =============================================================================
|
||||
// Initialization (minimal - just sets flag)
|
||||
// =============================================================================
|
||||
|
||||
int conv_bwdw_init()
|
||||
{
|
||||
g_bwdw_initialized = true;
|
||||
return 0; // Return 0 on success (consistent with other init functions)
|
||||
}
|
||||
|
||||
void conv_bwdw_cleanup() { g_bwdw_initialized = false; }
|
||||
|
||||
// =============================================================================
|
||||
// Problem Structure (same as main library)
|
||||
// =============================================================================
|
||||
|
||||
struct ConvBwdwProblemC
|
||||
{
|
||||
int N, G, C, K;
|
||||
int input_d, input_h, input_w;
|
||||
int filter_z, filter_y, filter_x;
|
||||
int stride_d, stride_h, stride_w;
|
||||
int pad_d, pad_h, pad_w;
|
||||
int dilation_d, dilation_h, dilation_w;
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
// Backward Weight Execution
|
||||
// =============================================================================
|
||||
|
||||
#ifdef CONV_BWD_WEIGHT_AVAILABLE
|
||||
static ck_tile::conv::ConvParam build_conv_param(const ConvBwdwProblemC* prob)
|
||||
{
|
||||
const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1);
|
||||
|
||||
if(is_3d)
|
||||
{
|
||||
return ck_tile::conv::ConvParam{3,
|
||||
prob->G,
|
||||
prob->N,
|
||||
prob->K,
|
||||
prob->C,
|
||||
{prob->filter_z, prob->filter_y, prob->filter_x},
|
||||
{prob->input_d, prob->input_h, prob->input_w},
|
||||
{prob->stride_d, prob->stride_h, prob->stride_w},
|
||||
{prob->dilation_d, prob->dilation_h, prob->dilation_w},
|
||||
{prob->pad_d, prob->pad_h, prob->pad_w},
|
||||
{prob->pad_d, prob->pad_h, prob->pad_w}};
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::conv::ConvParam{2,
|
||||
prob->G,
|
||||
prob->N,
|
||||
prob->K,
|
||||
prob->C,
|
||||
{prob->filter_y, prob->filter_x},
|
||||
{prob->input_h, prob->input_w},
|
||||
{prob->stride_h, prob->stride_w},
|
||||
{prob->dilation_h, prob->dilation_w},
|
||||
{prob->pad_h, prob->pad_w},
|
||||
{prob->pad_h, prob->pad_w}};
|
||||
}
|
||||
}
|
||||
|
||||
static float run_bwd_weight_impl(const void* input_ptr,
|
||||
const void* grad_output_ptr,
|
||||
void* grad_weight_ptr,
|
||||
const ConvBwdwProblemC* prob,
|
||||
void* stream)
|
||||
{
|
||||
auto conv_param = build_conv_param(prob);
|
||||
|
||||
// Backward weight: A=input, B=grad_output, C=grad_weight
|
||||
ck_tile::GroupedConvBwdWeightHostArgs args(conv_param,
|
||||
input_ptr, // in_ptr = input
|
||||
grad_weight_ptr, // wei_ptr = grad_weight (output)
|
||||
{}, // ds_ptr
|
||||
grad_output_ptr, // out_ptr = grad_output
|
||||
1 // k_batch
|
||||
);
|
||||
|
||||
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
|
||||
|
||||
return SelectedConvBwdWeightLauncher::launch(args, stream_cfg);
|
||||
}
|
||||
#endif
|
||||
|
||||
float conv_bwdw_run(const void* input_ptr,
|
||||
const void* grad_output_ptr,
|
||||
void* grad_weight_ptr,
|
||||
const ConvBwdwProblemC* prob,
|
||||
void* stream)
|
||||
{
|
||||
#ifdef CONV_BWD_WEIGHT_AVAILABLE
|
||||
// Validate all required pointers before kernel launch
|
||||
if(!g_bwdw_initialized || !prob)
|
||||
return -1.0f;
|
||||
if(!input_ptr || !grad_output_ptr || !grad_weight_ptr)
|
||||
return -1.0f; // Null data pointer would cause kernel crash
|
||||
return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream);
|
||||
#else
|
||||
return -1.0f;
|
||||
#endif
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Info
|
||||
// =============================================================================
|
||||
|
||||
const char* conv_bwdw_version() { return "1.0.0"; }
|
||||
|
||||
int conv_bwdw_has_kernels()
|
||||
{
|
||||
#ifdef CONV_BWD_WEIGHT_AVAILABLE
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
int conv_bwdw_get_kernel_count()
|
||||
{
|
||||
#ifdef CONV_BWD_WEIGHT_AVAILABLE
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
int conv_bwdw_get_kernel_name(int index, char* buffer, int buffer_size)
|
||||
{
|
||||
#ifdef CONV_BWD_WEIGHT_AVAILABLE
|
||||
if(index != 0 || !buffer || buffer_size <= 0)
|
||||
return -1;
|
||||
std::strncpy(buffer, CONV_BWD_WEIGHT_KERNEL_NAME, buffer_size - 1);
|
||||
buffer[buffer_size - 1] = '\0';
|
||||
return 0;
|
||||
#else
|
||||
return -1;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
411
dispatcher/bindings/ctypes/conv_ctypes_lib.cpp
Normal file
411
dispatcher/bindings/ctypes/conv_ctypes_lib.cpp
Normal file
@@ -0,0 +1,411 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* Convolution Dispatcher ctypes Library
|
||||
*
|
||||
* Provides C API for Python ctypes integration.
|
||||
* Supports forward convolution. Backward operations require additional headers.
|
||||
*
|
||||
* REQUIRED: Forward kernel header must be force-included via -include flag.
|
||||
* OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE
|
||||
*
|
||||
* Usage from Python:
|
||||
* lib = ctypes.CDLL("libdispatcher_conv.so")
|
||||
* lib.conv_dispatcher_init()
|
||||
* lib.conv_dispatcher_run(...)
|
||||
*/
|
||||
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck_tile/dispatcher/conv_utils.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
|
||||
// Global state (using shared_ptr for safe memory management)
|
||||
static std::shared_ptr<ConvRegistry> g_registry = nullptr;
|
||||
static std::shared_ptr<ConvDispatcher> g_dispatcher = nullptr;
|
||||
static std::vector<const ConvKernelInstance*> g_kernels;
|
||||
|
||||
extern "C" {
|
||||
|
||||
// =============================================================================
|
||||
// Initialization
|
||||
// =============================================================================
|
||||
|
||||
int conv_dispatcher_init()
|
||||
{
|
||||
if(g_registry)
|
||||
return 0; // Already initialized
|
||||
|
||||
g_registry = std::make_shared<ConvRegistry>();
|
||||
g_dispatcher = std::make_shared<ConvDispatcher>(g_registry.get());
|
||||
|
||||
// Register kernel configurations using simple ConvKernelSet
|
||||
// (actual kernel launch uses the force-included SelectedConvKernelLauncher)
|
||||
using namespace ck_tile::dispatcher::conv_decl;
|
||||
|
||||
// Forward kernels (required - must be force-included)
|
||||
// Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb
|
||||
ConvKernelSet fwd_set;
|
||||
fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
|
||||
ConvAlgorithm()
|
||||
.tile(128, 128, 64) // tile_m x tile_n x tile_k
|
||||
.wave(2, 2, 1)
|
||||
.warp(32, 32, 16)
|
||||
.pipeline("compv4")
|
||||
.scheduler("intrawave"),
|
||||
"gfx942");
|
||||
g_registry->register_set(fwd_set, ConvRegistry::Priority::High);
|
||||
|
||||
#ifdef CONV_BWD_DATA_AVAILABLE
|
||||
// Backward data kernels
|
||||
// Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16
|
||||
ConvKernelSet bwd_data_set;
|
||||
bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2),
|
||||
ConvAlgorithm()
|
||||
.tile(128, 128, 64) // tile_m x tile_n x tile_k
|
||||
.wave(2, 2, 1)
|
||||
.warp(32, 32, 16)
|
||||
.pipeline("compv3")
|
||||
.scheduler("intrawave"),
|
||||
"gfx942");
|
||||
g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High);
|
||||
#endif
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int conv_dispatcher_cleanup()
|
||||
{
|
||||
// shared_ptr automatically handles cleanup when reset
|
||||
g_dispatcher.reset();
|
||||
g_registry.reset();
|
||||
g_kernels.clear();
|
||||
return 0;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Registry Management
|
||||
// =============================================================================
|
||||
|
||||
int conv_dispatcher_get_kernel_count()
|
||||
{
|
||||
if(!g_registry)
|
||||
return 0;
|
||||
return static_cast<int>(g_registry->size());
|
||||
}
|
||||
|
||||
int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size)
|
||||
{
|
||||
if(index < 0 || !buffer || buffer_size <= 0)
|
||||
return -1;
|
||||
|
||||
if(!g_registry)
|
||||
return -1;
|
||||
|
||||
// Use registry to get kernel names (they are registered with full names)
|
||||
const auto& kernels = g_registry->all_kernels();
|
||||
if(static_cast<size_t>(index) >= kernels.size())
|
||||
return -1;
|
||||
|
||||
const auto* kernel = kernels[index];
|
||||
std::strncpy(buffer, kernel->name().c_str(), buffer_size - 1);
|
||||
buffer[buffer_size - 1] = '\0';
|
||||
return 0;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Problem Definition
|
||||
// =============================================================================
|
||||
|
||||
struct ConvProblemC
|
||||
{
|
||||
int N, G, C, K;
|
||||
int input_d, input_h, input_w;
|
||||
int filter_z, filter_y, filter_x;
|
||||
int stride_d, stride_h, stride_w;
|
||||
int pad_d, pad_h, pad_w;
|
||||
int dilation_d, dilation_h, dilation_w;
|
||||
int direction; // 0=forward, 1=bwd_data, 2=bwd_weight
|
||||
};
|
||||
|
||||
// =============================================================================
|
||||
// Kernel Selection
|
||||
// =============================================================================
|
||||
|
||||
int conv_dispatcher_is_supported(const ConvProblemC* prob)
|
||||
{
|
||||
if(!g_registry || !prob)
|
||||
return 0;
|
||||
|
||||
ConvProblem problem;
|
||||
problem.N = prob->N;
|
||||
problem.G = prob->G;
|
||||
problem.C = prob->C;
|
||||
problem.K = prob->K;
|
||||
problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
|
||||
problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
|
||||
problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
|
||||
problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
|
||||
problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
|
||||
problem.op = static_cast<ConvOp>(prob->direction);
|
||||
problem.compute_output_size();
|
||||
|
||||
const auto* kernel = g_dispatcher->select(problem);
|
||||
return kernel ? 1 : 0;
|
||||
}
|
||||
|
||||
int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size)
|
||||
{
|
||||
if(!g_registry || !prob || !kernel_name || buffer_size <= 0)
|
||||
return -1;
|
||||
|
||||
ConvProblem problem;
|
||||
problem.N = prob->N;
|
||||
problem.G = prob->G;
|
||||
problem.C = prob->C;
|
||||
problem.K = prob->K;
|
||||
problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
|
||||
problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
|
||||
problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
|
||||
problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
|
||||
problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
|
||||
problem.op = static_cast<ConvOp>(prob->direction);
|
||||
problem.compute_output_size();
|
||||
|
||||
const auto* kernel = g_dispatcher->select(problem);
|
||||
if(!kernel)
|
||||
return -1;
|
||||
|
||||
std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1);
|
||||
kernel_name[buffer_size - 1] = '\0';
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Convolution Execution
|
||||
// =============================================================================
|
||||
|
||||
// Helper to build ConvParam
|
||||
static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob)
|
||||
{
|
||||
// Determine if this is 2D or 3D convolution
|
||||
const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1);
|
||||
|
||||
if(is_3d)
|
||||
{
|
||||
// 3D convolution: use all spatial dimensions
|
||||
return ck_tile::conv::ConvParam{3,
|
||||
prob->G,
|
||||
prob->N,
|
||||
prob->K,
|
||||
prob->C,
|
||||
{prob->filter_z, prob->filter_y, prob->filter_x},
|
||||
{prob->input_d, prob->input_h, prob->input_w},
|
||||
{prob->stride_d, prob->stride_h, prob->stride_w},
|
||||
{prob->dilation_d, prob->dilation_h, prob->dilation_w},
|
||||
{prob->pad_d, prob->pad_h, prob->pad_w},
|
||||
{prob->pad_d, prob->pad_h, prob->pad_w}};
|
||||
}
|
||||
else
|
||||
{
|
||||
// 2D convolution: only use H, W dimensions
|
||||
return ck_tile::conv::ConvParam{2,
|
||||
prob->G,
|
||||
prob->N,
|
||||
prob->K,
|
||||
prob->C,
|
||||
{prob->filter_y, prob->filter_x},
|
||||
{prob->input_h, prob->input_w},
|
||||
{prob->stride_h, prob->stride_w},
|
||||
{prob->dilation_h, prob->dilation_w},
|
||||
{prob->pad_h, prob->pad_w},
|
||||
{prob->pad_h, prob->pad_w}};
|
||||
}
|
||||
}
|
||||
|
||||
// Forward convolution (required - kernel header must be force-included)
|
||||
static float run_forward(const void* input_ptr,
|
||||
const void* weight_ptr,
|
||||
void* output_ptr,
|
||||
const ConvProblemC* prob,
|
||||
void* stream)
|
||||
{
|
||||
auto conv_param = build_conv_param(prob);
|
||||
|
||||
ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1);
|
||||
|
||||
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
|
||||
|
||||
// SelectedConvKernelLauncher is defined in the force-included forward kernel header
|
||||
return SelectedConvKernelLauncher::launch(args, stream_cfg);
|
||||
}
|
||||
|
||||
#ifdef CONV_BWD_DATA_AVAILABLE
|
||||
// Backward data convolution (optional)
|
||||
// Computes: grad_input = conv_bwd_data(weight, grad_output)
|
||||
//
|
||||
// Parameters:
|
||||
// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT)
|
||||
// weight_ptr: W - frozen weights (const, read-only INPUT)
|
||||
// grad_input_ptr: dX - gradient for input (writable, OUTPUT)
|
||||
static float run_bwd_data(const void* grad_output_ptr,
|
||||
const void* weight_ptr,
|
||||
void* grad_input_ptr,
|
||||
const ConvProblemC* prob,
|
||||
void* stream)
|
||||
{
|
||||
auto conv_param = build_conv_param(prob);
|
||||
|
||||
// CK Tile API uses tensor POSITION names (from forward pass), not data flow:
|
||||
// in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data)
|
||||
// wei_ptr = weight tensor = weight_ptr (W, const)
|
||||
// out_ptr = output tensor position = grad_output_ptr (dY, INPUT to bwd_data)
|
||||
ck_tile::GroupedConvBwdDataHostArgs args(
|
||||
conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1);
|
||||
|
||||
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
|
||||
|
||||
return SelectedConvBwdDataLauncher::launch(args, stream_cfg);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CONV_BWD_WEIGHT_AVAILABLE
|
||||
// Backward weight convolution (optional)
|
||||
// Parameters:
|
||||
// input_ptr: original forward input X (const, read-only)
|
||||
// grad_output_ptr: gradient from next layer dY (const, read-only)
|
||||
// grad_weight_ptr: gradient of weights dW (writable, OUTPUT)
|
||||
static float run_bwd_weight(const void* input_ptr,
|
||||
const void* grad_output_ptr,
|
||||
void* grad_weight_ptr,
|
||||
const ConvProblemC* prob,
|
||||
void* stream)
|
||||
{
|
||||
auto conv_param = build_conv_param(prob);
|
||||
|
||||
// GroupedConvBwdWeightHostArgs constructor order:
|
||||
// (param, in=X, wei=dW (output), ds, out=dY (input), k_batch)
|
||||
// Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output)
|
||||
ck_tile::GroupedConvBwdWeightHostArgs args(
|
||||
conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1);
|
||||
|
||||
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
|
||||
|
||||
return SelectedConvBwdWeightLauncher::launch(args, stream_cfg);
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
|
||||
* @brief Execute convolution based on direction specified in prob
|
||||
*
|
||||
* Parameter mapping varies by direction:
|
||||
* Forward (direction=0):
|
||||
* input_ptr = X (input tensor)
|
||||
* weight_ptr = W (weight tensor)
|
||||
* output_ptr = Y (output buffer)
|
||||
*
|
||||
* Backward Data (direction=1):
|
||||
* input_ptr = dY (grad_output - gradient from next layer)
|
||||
* weight_ptr = W (weight tensor, frozen)
|
||||
* output_ptr = dX (grad_input buffer)
|
||||
*
|
||||
* Backward Weight (direction=2):
|
||||
* input_ptr = X (forward input tensor)
|
||||
* weight_ptr = dY (grad_output - gradient from next layer)
|
||||
* output_ptr = dW (grad_weight buffer)
|
||||
*/
|
||||
float conv_dispatcher_run(const void* input_ptr,
|
||||
const void* weight_ptr,
|
||||
void* output_ptr,
|
||||
const ConvProblemC* prob,
|
||||
void* stream)
|
||||
{
|
||||
// Validate all required pointers before kernel launch
|
||||
if(!g_dispatcher || !prob)
|
||||
return -1.0f;
|
||||
if(!input_ptr || !weight_ptr || !output_ptr)
|
||||
return -1.0f; // Null data pointer would cause kernel crash
|
||||
|
||||
// Build problem for kernel selection
|
||||
ConvProblem problem;
|
||||
problem.N = prob->N;
|
||||
problem.G = prob->G;
|
||||
problem.C = prob->C;
|
||||
problem.K = prob->K;
|
||||
problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
|
||||
problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
|
||||
problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
|
||||
problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
|
||||
problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
|
||||
problem.op = static_cast<ConvOp>(prob->direction);
|
||||
problem.compute_output_size();
|
||||
|
||||
// Select kernel
|
||||
const auto* kernel = g_dispatcher->select(problem);
|
||||
if(!kernel)
|
||||
return -1.0f;
|
||||
|
||||
// Dispatch based on direction
|
||||
switch(prob->direction)
|
||||
{
|
||||
case 0: // Forward (always available)
|
||||
return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream);
|
||||
|
||||
#ifdef CONV_BWD_DATA_AVAILABLE
|
||||
case 1: // Backward data
|
||||
// Convention: caller passes (grad_output, weight, grad_input_buffer)
|
||||
// in the (input_ptr, weight_ptr, output_ptr) slots respectively.
|
||||
// run_bwd_data expects: (grad_output, weight, grad_input)
|
||||
return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream);
|
||||
#endif
|
||||
|
||||
#ifdef CONV_BWD_WEIGHT_AVAILABLE
|
||||
case 2: // Backward weight
|
||||
// Convention: caller passes (input, grad_output, grad_weight_buffer)
|
||||
// in the (input_ptr, weight_ptr, output_ptr) slots respectively.
|
||||
// run_bwd_weight expects: (input, grad_output, grad_weight)
|
||||
return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream);
|
||||
#endif
|
||||
|
||||
default: return -1.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Info
|
||||
// =============================================================================
|
||||
|
||||
const char* conv_dispatcher_version() { return "1.0.0"; }
|
||||
|
||||
int conv_dispatcher_has_kernels()
|
||||
{
|
||||
return 1; // Forward kernel is required
|
||||
}
|
||||
|
||||
int conv_dispatcher_has_bwd_data()
|
||||
{
|
||||
#ifdef CONV_BWD_DATA_AVAILABLE
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
int conv_dispatcher_has_bwd_weight()
|
||||
{
|
||||
#ifdef CONV_BWD_WEIGHT_AVAILABLE
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
401
dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp
Normal file
401
dispatcher/bindings/ctypes/gemm_ctypes_lib.cpp
Normal file
@@ -0,0 +1,401 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* GEMM Dispatcher ctypes Library
|
||||
*
|
||||
* Provides C API for Python ctypes integration.
|
||||
* Kernel header included via -include at compile time.
|
||||
*
|
||||
* Usage from Python:
|
||||
* lib = ctypes.CDLL("libdispatcher_gemm.so")
|
||||
* lib.dispatcher_init()
|
||||
* lib.dispatcher_run_gemm(...)
|
||||
*/
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/dispatcher/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/registry.hpp"
|
||||
#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp"
|
||||
|
||||
// Kernel header included via -include compiler flag
|
||||
// Defines: ADataType, BDataType, CDataType, AccDataType, SelectedKernel, KERNEL_NAME
|
||||
|
||||
// GPU architecture - can be overridden via -DGFX_ARCH="gfx90a" at compile time
|
||||
#ifndef GFX_ARCH
|
||||
#define GFX_ARCH "gfx942"
|
||||
#endif
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::backends;
|
||||
using Priority = ck_tile::dispatcher::Registry::Priority;
|
||||
|
||||
// Global dispatcher (initialized once, managed via shared_ptr for safe cleanup)
|
||||
static std::shared_ptr<Dispatcher> g_dispatcher = nullptr;
|
||||
static bool g_initialized = false;
|
||||
|
||||
#define HIP_CHECK(call) \
|
||||
{ \
|
||||
hipError_t err = call; \
|
||||
if(err != hipSuccess) \
|
||||
{ \
|
||||
return -1; \
|
||||
} \
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
/**
|
||||
* Initialize dispatcher with a kernel
|
||||
* Must be called before run_gemm
|
||||
*
|
||||
* Returns: 0 on success, -1 on error
|
||||
*/
|
||||
int dispatcher_initialize()
{
    if(g_initialized)
    {
        return 0; // Already initialized
    }

    // Create kernel key from the force-included kernel header.
    // NOTE(review): the signature/algorithm below are hard-coded (FP16 A/B/C,
    // FP32 accumulation, row/col/row layouts, 128x128x32 tiles). They must
    // describe the kernel pulled in via the -include header, or select_kernel()
    // will never match it — confirm whenever the generated kernel changes.
    KernelKey key;
    key.signature.dtype_a = DataType::FP16;
    key.signature.dtype_b = DataType::FP16;
    key.signature.dtype_c = DataType::FP16;
    key.signature.dtype_acc = DataType::FP32;
    key.signature.layout_a = LayoutTag::RowMajor;
    key.signature.layout_b = LayoutTag::ColMajor;
    key.signature.layout_c = LayoutTag::RowMajor;
    key.signature.transpose_a = false;
    key.signature.transpose_b = false;
    key.signature.grouped = false;
    key.signature.split_k = 1;
    key.signature.elementwise_op = "PassThrough";
    key.signature.num_d_tensors = 0;
    key.signature.structured_sparsity = false;

    key.algorithm.tile_shape = {128, 128, 32};
    key.algorithm.wave_shape = {2, 2, 1};
    key.algorithm.warp_tile_shape = {32, 32, 16};
    key.algorithm.pipeline = Pipeline::CompV4;
    key.algorithm.scheduler = Scheduler::Intrawave;
    key.algorithm.epilogue = Epilogue::CShuffle;
    key.algorithm.block_size = 256;
    key.algorithm.double_buffer = true;
    key.algorithm.persistent = false;
    key.algorithm.preshuffle = false;
    key.algorithm.transpose_c = false;
    key.algorithm.num_wave_groups = 1;
    key.gfx_arch = GFX_ARCH; // overridable via -DGFX_ARCH at compile time

    // Register kernel using types from force-included header
    auto kernel =
        create_generated_tile_kernel<SelectedKernel, ADataType, BDataType, CDataType, AccDataType>(
            key, KERNEL_NAME);

    // NOTE(review): clear() wipes any previously registered kernels, so this
    // library assumes exclusive ownership of the global registry.
    Registry::instance().clear();
    Registry::instance().register_kernel(kernel, Priority::High);

    // Create dispatcher (using shared_ptr for safe memory management)
    g_dispatcher = std::make_shared<Dispatcher>();
    g_initialized = true;

    return 0;
}
|
||||
|
||||
/**
|
||||
* Get kernel tile configuration
|
||||
*/
|
||||
int dispatcher_get_kernel_config(int* tile_m,
                                 int* tile_n,
                                 int* tile_k,
                                 int* warp_tile_m,
                                 int* warp_tile_n,
                                 int* warp_tile_k,
                                 int* warp_m,
                                 int* warp_n,
                                 int* warp_k)
{
    // Requires a prior successful dispatcher_initialize().
    if(!g_initialized)
    {
        return -1;
    }

    auto kernels = Registry::instance().get_all();
    if(kernels.empty())
    {
        return -1;
    }

    // Get configuration from first kernel. This build registers exactly one
    // kernel (see dispatcher_initialize), so index 0 is the one.
    auto& key  = kernels[0]->get_key();
    auto& algo = key.algorithm;

    // Every output pointer is optional; null pointers are simply skipped.
    if(tile_m)
        *tile_m = algo.tile_shape.m;
    if(tile_n)
        *tile_n = algo.tile_shape.n;
    if(tile_k)
        *tile_k = algo.tile_shape.k;
    if(warp_tile_m)
        *warp_tile_m = algo.warp_tile_shape.m;
    if(warp_tile_n)
        *warp_tile_n = algo.warp_tile_shape.n;
    if(warp_tile_k)
        *warp_tile_k = algo.warp_tile_shape.k;
    if(warp_m)
        *warp_m = algo.wave_shape.m;
    if(warp_n)
        *warp_n = algo.wave_shape.n;
    if(warp_k)
        *warp_k = algo.wave_shape.k;

    return 0;
}
|
||||
|
||||
/**
|
||||
* Get the selected kernel name for a problem
|
||||
*/
|
||||
int dispatcher_select_kernel(int64_t M, int64_t N, int64_t K, char* name_buffer, int buffer_size)
|
||||
{
|
||||
if(!g_initialized || !name_buffer || buffer_size <= 0)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
Problem problem(M, N, K);
|
||||
auto kernel = g_dispatcher->select_kernel(problem);
|
||||
|
||||
if(!kernel)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::string name = kernel->get_name();
|
||||
strncpy(name_buffer, name.c_str(), buffer_size - 1);
|
||||
name_buffer[buffer_size - 1] = '\0';
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a problem size is supported by available kernels
|
||||
*/
|
||||
int dispatcher_is_supported(int64_t M, int64_t N, int64_t K)
|
||||
{
|
||||
if(!g_initialized)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
if(M <= 0 || N <= 0 || K <= 0)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
Problem problem(M, N, K);
|
||||
auto kernel = g_dispatcher->select_kernel(problem);
|
||||
return kernel != nullptr ? 1 : 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Run GEMM on GPU via dispatcher
|
||||
*/
|
||||
int dispatcher_run_gemm(
|
||||
const void* A, const void* B, void* C, int64_t M, int64_t N, int64_t K, float* time_ms)
|
||||
{
|
||||
if(!g_initialized || !A || !B || !C)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
// First check if any kernel supports this problem
|
||||
Problem problem(M, N, K);
|
||||
auto kernel = g_dispatcher->select_kernel(problem);
|
||||
if(!kernel)
|
||||
{
|
||||
if(time_ms)
|
||||
{
|
||||
*time_ms = -1.0f;
|
||||
}
|
||||
return -2; // No suitable kernel
|
||||
}
|
||||
|
||||
// Cast to correct types (from force-included header)
|
||||
const ADataType* A_host = static_cast<const ADataType*>(A);
|
||||
const BDataType* B_host = static_cast<const BDataType*>(B);
|
||||
CDataType* C_host = static_cast<CDataType*>(C);
|
||||
|
||||
// Allocate GPU memory
|
||||
ADataType* A_dev = nullptr;
|
||||
BDataType* B_dev = nullptr;
|
||||
CDataType* C_dev = nullptr;
|
||||
|
||||
auto cleanup_gpu_mem = [&]() {
|
||||
if(A_dev)
|
||||
(void)hipFree(A_dev);
|
||||
if(B_dev)
|
||||
(void)hipFree(B_dev);
|
||||
if(C_dev)
|
||||
(void)hipFree(C_dev);
|
||||
};
|
||||
|
||||
if(hipMalloc(&A_dev, M * K * sizeof(ADataType)) != hipSuccess)
|
||||
{
|
||||
cleanup_gpu_mem();
|
||||
return -1;
|
||||
}
|
||||
if(hipMalloc(&B_dev, K * N * sizeof(BDataType)) != hipSuccess)
|
||||
{
|
||||
cleanup_gpu_mem();
|
||||
return -1;
|
||||
}
|
||||
if(hipMalloc(&C_dev, M * N * sizeof(CDataType)) != hipSuccess)
|
||||
{
|
||||
cleanup_gpu_mem();
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Copy input data to GPU
|
||||
if(hipMemcpy(A_dev, A_host, M * K * sizeof(ADataType), hipMemcpyHostToDevice) != hipSuccess)
|
||||
{
|
||||
cleanup_gpu_mem();
|
||||
return -1;
|
||||
}
|
||||
if(hipMemcpy(B_dev, B_host, K * N * sizeof(BDataType), hipMemcpyHostToDevice) != hipSuccess)
|
||||
{
|
||||
cleanup_gpu_mem();
|
||||
return -1;
|
||||
}
|
||||
if(hipMemset(C_dev, 0, M * N * sizeof(CDataType)) != hipSuccess)
|
||||
{
|
||||
cleanup_gpu_mem();
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Run GEMM via dispatcher
|
||||
float exec_time;
|
||||
try
|
||||
{
|
||||
exec_time = g_dispatcher->run(A_dev, B_dev, C_dev, problem);
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
cleanup_gpu_mem();
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Copy result back to host
|
||||
if(hipMemcpy(C_host, C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost) != hipSuccess)
|
||||
{
|
||||
cleanup_gpu_mem();
|
||||
return -1;
|
||||
}
|
||||
|
||||
if(time_ms)
|
||||
{
|
||||
*time_ms = exec_time;
|
||||
}
|
||||
|
||||
cleanup_gpu_mem();
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get kernel information
|
||||
*/
|
||||
// Name of the kernel baked in via the force-included header.
const char* dispatcher_get_kernel_name()
{
    return KERNEL_NAME;
}
|
||||
|
||||
/**
|
||||
* Initialize dispatcher (alias)
|
||||
*/
|
||||
// Backwards-compatible alias for dispatcher_initialize().
int dispatcher_init()
{
    return dispatcher_initialize();
}
|
||||
|
||||
/**
|
||||
* Get the number of registered kernels
|
||||
*/
|
||||
// Number of kernels currently registered in the global registry.
int dispatcher_get_kernel_count()
{
    const auto count = Registry::instance().size();
    return static_cast<int>(count);
}
|
||||
|
||||
/**
|
||||
* Export registry to JSON string
|
||||
*/
|
||||
static std::string g_json_buffer;
|
||||
|
||||
const char* dispatcher_export_registry_json()
{
    auto& registry = Registry::instance();

    // Build the JSON document in memory.
    // NOTE(review): kernel names/identifiers are inserted without JSON string
    // escaping; any quote or backslash in a name would corrupt the output —
    // confirm names are restricted, or add escaping.
    std::ostringstream json;
    json << "{\n";
    json << "  \"metadata\": {\n";
    json << "    \"timestamp\": \"" << __DATE__ << " " << __TIME__ << "\",\n";
    json << "    \"total_kernels\": " << registry.size() << ",\n";
    json << "    \"export_version\": \"1.0\",\n";
    json << "    \"dispatcher_version\": \"1.0.0\"\n";
    json << "  },\n";
    // Statistics buckets are emitted as empty placeholders in this version.
    json << "  \"statistics\": {\n";
    json << "    \"by_datatype\": {},\n";
    json << "    \"by_pipeline\": {},\n";
    json << "    \"by_scheduler\": {}\n";
    json << "  },\n";
    json << "  \"kernels\": [\n";

    auto kernels = registry.get_all();
    for(size_t i = 0; i < kernels.size(); ++i)
    {
        auto& kernel     = kernels[i];
        auto& key        = kernel->get_key();
        auto& algo       = key.algorithm;
        std::string name = kernel->get_name();

        json << "    {\n";
        json << "      \"identifier\": \"" << key.encode_identifier() << "\",\n";
        json << "      \"name\": \"" << name << "\",\n";
        json << "      \"algorithm\": {\n";
        json << "        \"tile_shape\": {\"m\": " << algo.tile_shape.m
             << ", \"n\": " << algo.tile_shape.n << ", \"k\": " << algo.tile_shape.k << "},\n";
        // wave/warp-tile fields are cast to unsigned so narrow integer types
        // print as numbers rather than characters.
        json << "        \"wave_shape\": {\"m\": " << unsigned(algo.wave_shape.m)
             << ", \"n\": " << unsigned(algo.wave_shape.n)
             << ", \"k\": " << unsigned(algo.wave_shape.k) << "},\n";
        json << "        \"warp_tile_shape\": {\"m\": " << unsigned(algo.warp_tile_shape.m)
             << ", \"n\": " << unsigned(algo.warp_tile_shape.n)
             << ", \"k\": " << unsigned(algo.warp_tile_shape.k) << "},\n";
        json << "        \"block_size\": " << algo.block_size << ",\n";
        json << "        \"persistent\": " << (algo.persistent ? "true" : "false") << ",\n";
        json << "        \"double_buffer\": " << (algo.double_buffer ? "true" : "false") << ",\n";
        json << "        \"preshuffle\": " << (algo.preshuffle ? "true" : "false") << ",\n";
        json << "        \"transpose_c\": " << (algo.transpose_c ? "true" : "false") << "\n";
        json << "      }\n";
        json << "    }";
        // Comma between array entries, none after the last.
        if(i < kernels.size() - 1)
        {
            json << ",";
        }
        json << "\n";
    }

    json << "  ]\n";
    json << "}\n";

    // The returned pointer stays valid until the next call to this function
    // (it points into the static g_json_buffer).
    g_json_buffer = json.str();
    return g_json_buffer.c_str();
}
|
||||
|
||||
/**
|
||||
* Cleanup dispatcher resources
|
||||
*/
|
||||
void dispatcher_cleanup()
|
||||
{
|
||||
g_dispatcher.reset();
|
||||
g_initialized = false;
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
206
dispatcher/bindings/ctypes/gpu_helper.cpp
Normal file
206
dispatcher/bindings/ctypes/gpu_helper.cpp
Normal file
@@ -0,0 +1,206 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* GPU Helper - C++ executable for GPU GEMM execution
|
||||
*
|
||||
* A CLI tool for Python to execute GPU GEMM with generated kernels.
|
||||
* Usage: gpu_helper <M> <N> <K> [--validate]
|
||||
*
|
||||
* Kernel header included via -include flag at compile time.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cstring>
|
||||
#include <cmath>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck_tile/dispatcher/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/registry.hpp"
|
||||
#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp"
|
||||
|
||||
// Kernel header included via -include compiler flag
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::backends;
|
||||
using Priority = ck_tile::dispatcher::Registry::Priority;
|
||||
|
||||
#define HIP_CHECK(call) \
|
||||
{ \
|
||||
hipError_t err = call; \
|
||||
if(err != hipSuccess) \
|
||||
{ \
|
||||
std::cerr << "HIP_ERROR: " << hipGetErrorString(err) << "\n"; \
|
||||
exit(1); \
|
||||
} \
|
||||
}
|
||||
|
||||
// CPU reference GEMM (for validation)
|
||||
// CPU reference GEMM for validation: C = A * B with A row-major (MxK) and
// B column-major (KxN); accumulation is done in float regardless of T.
template <typename T>
void cpu_gemm(
    const std::vector<T>& A, const std::vector<T>& B, std::vector<T>& C, int M, int N, int K)
{
    for(int row = 0; row < M; ++row)
    {
        const T* a_row = &A[row * K];
        for(int col = 0; col < N; ++col)
        {
            // B is column-major, so column `col` is contiguous at offset col*K.
            const T* b_col = &B[col * K];
            float sum      = 0.0f;
            for(int k = 0; k < K; ++k)
            {
                sum += float(a_row[k]) * float(b_col[k]);
            }
            C[row * N + col] = T(sum);
        }
    }
}
|
||||
|
||||
int main(int argc, char** argv)
{
    // Parse arguments: gpu_helper <M> <N> <K> [--validate]
    if(argc < 4)
    {
        std::cerr << "Usage: " << argv[0] << " <M> <N> <K> [--validate]\n";
        std::cerr << "\nOptions:\n";
        std::cerr << "  M, N, K : Problem dimensions\n";
        std::cerr << "  --validate : Compare GPU results with CPU reference\n";
        return 1;
    }

    int M         = std::atoi(argv[1]);
    int N         = std::atoi(argv[2]);
    int K         = std::atoi(argv[3]);
    bool validate = (argc > 4 && std::string(argv[4]) == "--validate");

    // Output in JSON-like format for easy Python parsing. The document is
    // streamed incrementally, so an early failure still prints a closed object.
    std::cout << "{" << std::endl;
    std::cout << "  \"problem\": {\"M\": " << M << ", \"N\": " << N << ", \"K\": " << K << "},"
              << std::endl;
    std::cout << "  \"kernel\": \"" << KERNEL_NAME << "\"," << std::endl;

    // Register kernel.
    // NOTE(review): this key is hard-coded (FP16 row/col/row, gfx942) and must
    // describe the kernel pulled in via the -include header — confirm when the
    // generated kernel changes.
    KernelKey key;
    key.signature.dtype_a = DataType::FP16;
    key.signature.dtype_b = DataType::FP16;
    key.signature.dtype_c = DataType::FP16;
    key.signature.dtype_acc = DataType::FP32;
    key.signature.layout_a = LayoutTag::RowMajor;
    key.signature.layout_b = LayoutTag::ColMajor;
    key.signature.layout_c = LayoutTag::RowMajor;
    key.signature.transpose_a = false;
    key.signature.transpose_b = false;
    key.signature.grouped = false;
    key.signature.split_k = 1;
    key.signature.elementwise_op = "PassThrough";
    key.signature.num_d_tensors = 0;
    key.signature.structured_sparsity = false;

    key.algorithm.tile_shape = {128, 128, 32};
    key.algorithm.wave_shape = {2, 2, 1};
    key.algorithm.warp_tile_shape = {32, 32, 16};
    key.algorithm.pipeline = Pipeline::CompV4;
    key.algorithm.scheduler = Scheduler::Intrawave;
    key.algorithm.epilogue = Epilogue::CShuffle;
    key.algorithm.block_size = 256;
    key.algorithm.double_buffer = true;
    key.algorithm.persistent = false;
    key.algorithm.preshuffle = false;
    key.algorithm.transpose_c = false;
    key.algorithm.num_wave_groups = 1;
    key.gfx_arch = "gfx942";

    auto kernel =
        create_generated_tile_kernel<SelectedKernel, ADataType, BDataType, CDataType, AccDataType>(
            key, KERNEL_NAME);

    Registry::instance().clear();
    Registry::instance().register_kernel(kernel, Priority::High);

    Dispatcher dispatcher;
    Problem problem(M, N, K);

    auto selected = dispatcher.select_kernel(problem);
    if(!selected)
    {
        std::cout << "  \"error\": \"No kernel selected\"" << std::endl;
        std::cout << "}" << std::endl;
        return 1;
    }

    std::cout << "  \"selected_kernel\": \"" << selected->get_name() << "\"," << std::endl;

    // Prepare data: A=1, B=1, so C should be K
    std::vector<ADataType> A_host(M * K, ADataType(1.0f));
    std::vector<BDataType> B_host(K * N, BDataType(1.0f));
    std::vector<CDataType> C_gpu(M * N);

    // GPU execution: allocate, upload, zero output. HIP_CHECK exits the
    // process on any HIP error.
    ADataType *A_dev, *B_dev;
    CDataType* C_dev;

    HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType)));
    HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType)));
    HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType)));

    HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice));
    HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice));
    HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType)));

    // run() returns the kernel execution time in milliseconds.
    float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem);

    HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));

    // Calculate performance: 2*M*N*K FLOPs for one GEMM.
    double flops  = 2.0 * M * N * K;
    double tflops = (flops / (gpu_time * 1e-3)) / 1e12;

    std::cout << "  \"execution\": {" << std::endl;
    std::cout << "    \"time_ms\": " << gpu_time << "," << std::endl;
    std::cout << "    \"tflops\": " << tflops << "," << std::endl;
    std::cout << "    \"flops\": " << (long long)flops << std::endl;
    std::cout << "  }," << std::endl;

    // Validation against the CPU reference (optional, O(M*N*K) on the host).
    if(validate)
    {
        std::vector<CDataType> C_cpu(M * N);
        cpu_gemm(A_host, B_host, C_cpu, M, N, K);

        int correct     = 0;
        float max_error = 0.0f;

        for(int i = 0; i < M * N; i++)
        {
            float gpu_val = float(C_gpu[i]);
            float cpu_val = float(C_cpu[i]);
            // Relative error with a small epsilon to avoid division by zero.
            float error = std::abs(gpu_val - cpu_val) / (std::abs(cpu_val) + 1e-5f);

            max_error = std::max(max_error, error);

            // 2% per-element tolerance.
            if(error < 0.02f)
            {
                correct++;
            }
        }

        float accuracy = 100.0f * correct / (M * N);

        std::cout << "  \"validation\": {" << std::endl;
        std::cout << "    \"accuracy\": " << accuracy << "," << std::endl;
        std::cout << "    \"max_error\": " << max_error << "," << std::endl;
        std::cout << "    \"correct_elements\": " << correct << "," << std::endl;
        std::cout << "    \"total_elements\": " << M * N << std::endl;
        std::cout << "  }," << std::endl;
    }

    std::cout << "  \"status\": \"success\"" << std::endl;
    std::cout << "}" << std::endl;

    // Cleanup
    HIP_CHECK(hipFree(A_dev));
    HIP_CHECK(hipFree(B_dev));
    HIP_CHECK(hipFree(C_dev));

    return 0;
}
|
||||
197
dispatcher/codegen/ADDING_NEW_GPU.md
Normal file
197
dispatcher/codegen/ADDING_NEW_GPU.md
Normal file
@@ -0,0 +1,197 @@
|
||||
# Adding New GPU Architecture Support
|
||||
|
||||
Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatcher.
|
||||
|
||||
> **See also:** [Main Dispatcher README](../README.md) | [Codegen README](README.md)
|
||||
|
||||
## Overview
|
||||
|
||||
The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications:
|
||||
|
||||
```
|
||||
arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python)
|
||||
→ arch_specs_generated.hpp (C++)
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Edit arch_specs.json
|
||||
# 2. Run generator
|
||||
python generate_arch_specs.py
|
||||
# 3. Rebuild
|
||||
cd ../build && cmake --build . -j8
|
||||
# 4. Test
|
||||
ctest
|
||||
```
|
||||
|
||||
## Step-by-Step Guide
|
||||
|
||||
### Step 1: Edit arch_specs.json
|
||||
|
||||
Add new architecture under `"architectures"`:
|
||||
|
||||
```json
|
||||
{
|
||||
"architectures": {
|
||||
"gfx1100": {
|
||||
"family": "rdna3",
|
||||
"description": "AMD Radeon RX 7000 series (RDNA3)",
|
||||
"warp_size": 32,
|
||||
"lds_capacity_kb": 64,
|
||||
"warp_configs": [
|
||||
[2, 4, 1],
|
||||
[4, 2, 1]
|
||||
],
|
||||
"warp_tile_combos": {
|
||||
"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]],
|
||||
"bf16_bf16_bf16": [[16, 16, 16], [32, 32, 16]]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Configuration Fields
|
||||
|
||||
| Field | Description | Example |
|
||||
|-------|-------------|---------|
|
||||
| `family` | GPU family | `"cdna3"`, `"rdna4"` |
|
||||
| `description` | Human-readable name | `"AMD Instinct MI300"` |
|
||||
| `warp_size` | Wave/warp size | `64` (CDNA), `32` (RDNA) |
|
||||
| `lds_capacity_kb` | LDS memory in KB | `64` |
|
||||
| `warp_configs` | Valid `[warp_m, warp_n, warp_k]` | `[[2,2,1], [4,4,1]]` |
|
||||
| `warp_tile_combos` | Warp tiles per dtype | See below |
|
||||
|
||||
### Step 3: Warp Tile Combinations
|
||||
|
||||
Map data type combinations to valid warp tile sizes:
|
||||
|
||||
```json
|
||||
"warp_tile_combos": {
|
||||
"fp16_fp16_fp16": [[32, 32, 8], [16, 16, 16], [32, 32, 16]],
|
||||
"bf16_bf16_bf16": [[32, 32, 8], [16, 16, 16]],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]]
|
||||
}
|
||||
```
|
||||
|
||||
Key format: `{A_dtype}_{B_dtype}_{C_dtype}`
|
||||
|
||||
### Step 4: Run Generator
|
||||
|
||||
```bash
|
||||
cd dispatcher/codegen
|
||||
python generate_arch_specs.py
|
||||
```
|
||||
|
||||
This generates:
|
||||
- `arch_specs_generated.py` (Python module)
|
||||
- `../include/ck_tile/dispatcher/arch_specs_generated.hpp` (C++ header)
|
||||
|
||||
### Step 5: Rebuild and Test
|
||||
|
||||
```bash
|
||||
cd ../build
|
||||
cmake --build . -j8
|
||||
ctest --output-on-failure
|
||||
```
|
||||
|
||||
### Step 6: Verify
|
||||
|
||||
```python
|
||||
from arch_filter import ArchFilter
|
||||
|
||||
filter = ArchFilter("gfx1100")
|
||||
is_valid = filter.is_kernel_valid(
|
||||
datatype_a="fp16", datatype_b="fp16", datatype_c="fp16",
|
||||
tile_m=128, tile_n=128, tile_k=32,
|
||||
warp_m=2, warp_n=2, warp_k=1,
|
||||
warp_tile_m=16, warp_tile_n=16, warp_tile_k=16
|
||||
)
|
||||
print(f"Valid: {is_valid}")
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
### Supported Data Types
|
||||
|
||||
| Key | Description |
|
||||
|-----|-------------|
|
||||
| `fp16` | Half precision (16-bit) |
|
||||
| `bf16` | Brain float 16 |
|
||||
| `fp32` | Single precision (32-bit) |
|
||||
| `fp64` | Double precision (64-bit) |
|
||||
| `fp8` | 8-bit float (E4M3) |
|
||||
| `bf8` | 8-bit brain float (E5M2) |
|
||||
| `int8` | 8-bit integer |
|
||||
| `int4` | 4-bit integer |
|
||||
|
||||
### GPU Families
|
||||
|
||||
| Family | Description |
|
||||
|--------|-------------|
|
||||
| `cdna2` | MI200 series (gfx90a) |
|
||||
| `cdna3` | MI300 series (gfx942) |
|
||||
| `cdna4` | MI350 series (gfx950) |
|
||||
| `rdna3` | RX 7000 series (gfx1100) |
|
||||
| `rdna4` | RX 9000 series (gfx1201) |
|
||||
|
||||
### Pipeline LDS Limits
|
||||
|
||||
| Pipeline | LDS Limit |
|
||||
|----------|-----------|
|
||||
| `compv4` | 32 KB |
|
||||
| `preshufflev2` | 32 KB |
|
||||
| `default` | 64 KB |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Unknown GPU architecture"
|
||||
|
||||
1. Check architecture key matches exactly (e.g., `"gfx942"` not `"GFX942"`)
|
||||
2. Verify you ran `generate_arch_specs.py`
|
||||
3. Rebuild C++ code
|
||||
|
||||
### Kernels being rejected
|
||||
|
||||
```python
|
||||
from arch_filter import ArchFilter, KernelConfig
|
||||
|
||||
filter = ArchFilter("gfx942")
|
||||
result = filter.validate_kernel(config)
|
||||
print(f"Valid: {result.valid}")
|
||||
for error in result.errors:
|
||||
print(f" Error: {error}")
|
||||
```
|
||||
|
||||
### Missing warp tile combination
|
||||
|
||||
1. Check `warp_tile_combos` in `arch_specs.json`
|
||||
2. Ensure `[warp_tile_m, warp_tile_n, warp_tile_k]` is in the list
|
||||
3. Verify data type key format
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
codegen/
|
||||
├── arch_specs.json # Single source of truth (EDIT THIS)
|
||||
├── generate_arch_specs.py # Generator script
|
||||
├── arch_specs_generated.py # Generated Python module
|
||||
└── ADDING_NEW_GPU.md # This file
|
||||
|
||||
include/ck_tile/dispatcher/
|
||||
├── arch_specs_generated.hpp # Generated C++ header
|
||||
└── arch_filter.hpp # C++ filter
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Test thoroughly** - Run all tests after adding a new GPU
|
||||
2. **Start minimal** - Add only validated configurations
|
||||
3. **Document sources** - Note where warp tile combinations came from
|
||||
4. **Keep in sync** - If using tile_engine, keep both updated
|
||||
|
||||
---
|
||||
|
||||
> **More info:** See [../README.md](../README.md) for full documentation.
|
||||
125
dispatcher/codegen/CMakeLists.txt
Normal file
125
dispatcher/codegen/CMakeLists.txt
Normal file
@@ -0,0 +1,125 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# CK Tile GEMM Unified Code Generator
|
||||
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
|
||||
# Find Python
|
||||
find_package(Python3 COMPONENTS Interpreter REQUIRED)
|
||||
|
||||
# Configuration
|
||||
set(CODEGEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/unified_gemm_codegen.py")
|
||||
set(CODEGEN_CONFIG "${CMAKE_CURRENT_SOURCE_DIR}/default_config.json")
|
||||
set(CODEGEN_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm")
|
||||
|
||||
# Configurable options
|
||||
set(CK_TILE_GEMM_DATATYPE "fp16" CACHE STRING "GEMM data type (fp16, bf16, fp32, fp8, bf8, int8)")
|
||||
set(CK_TILE_GEMM_LAYOUT "rcr" CACHE STRING "GEMM layout (rcr, rrr, crr, ccr)")
|
||||
set(CK_TILE_GEMM_VARIANTS "standard" CACHE STRING "GEMM variants (standard, preshuffle, multi_d)")
|
||||
set(CK_TILE_GEMM_GPU_TARGET "gfx942" CACHE STRING "Target GPU architecture")
|
||||
set(CK_TILE_GEMM_PARALLEL ON CACHE BOOL "Enable parallel generation")
|
||||
|
||||
# Custom target to run code generation
|
||||
add_custom_target(generate_tile_gemm_kernels
|
||||
COMMAND ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT}
|
||||
--output-dir ${CODEGEN_OUTPUT_DIR}
|
||||
--datatype ${CK_TILE_GEMM_DATATYPE}
|
||||
--layout ${CK_TILE_GEMM_LAYOUT}
|
||||
--gpu-target ${CK_TILE_GEMM_GPU_TARGET}
|
||||
--config ${CODEGEN_CONFIG}
|
||||
--variants ${CK_TILE_GEMM_VARIANTS}
|
||||
$<$<NOT:$<BOOL:${CK_TILE_GEMM_PARALLEL}>>:--no-parallel>
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
COMMENT "Generating CK Tile GEMM kernels and dispatcher wrappers..."
|
||||
VERBATIM
|
||||
)
|
||||
|
||||
# Create output directory
|
||||
file(MAKE_DIRECTORY ${CODEGEN_OUTPUT_DIR})
|
||||
|
||||
# Add generated headers to include path
|
||||
include_directories(${CODEGEN_OUTPUT_DIR})
|
||||
|
||||
# Installation
|
||||
install(FILES
|
||||
${CODEGEN_SCRIPT}
|
||||
${CODEGEN_CONFIG}
|
||||
README.md
|
||||
DESTINATION share/ck_tile/codegen
|
||||
)
|
||||
|
||||
# Helper function for projects to generate kernels
|
||||
# Generate GEMM kernels at configure time for a consuming project.
#
# Usage:
#   ck_tile_generate_gemm_kernels(
#     [OUTPUT_DIR <dir>] [DATATYPE <dt>] [LAYOUT <layout>]
#     [GPU_TARGET <arch>] [CONFIG <json>] [VARIANTS <v1> [<v2> ...]]
#     [PARALLEL])
#
# Runs unified_gemm_codegen.py synchronously via execute_process and aborts
# the configure step (FATAL_ERROR) if generation fails. Relies on the
# file-level CODEGEN_SCRIPT and Python3_EXECUTABLE variables being set.
function(ck_tile_generate_gemm_kernels)
    set(options PARALLEL)
    set(oneValueArgs OUTPUT_DIR DATATYPE LAYOUT GPU_TARGET CONFIG)
    set(multiValueArgs VARIANTS)
    cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    # Set defaults for any argument the caller omitted.
    if(NOT ARG_OUTPUT_DIR)
        set(ARG_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm")
    endif()
    if(NOT ARG_DATATYPE)
        set(ARG_DATATYPE "fp16")
    endif()
    if(NOT ARG_LAYOUT)
        set(ARG_LAYOUT "rcr")
    endif()
    if(NOT ARG_GPU_TARGET)
        set(ARG_GPU_TARGET "gfx942")
    endif()
    if(NOT ARG_CONFIG)
        set(ARG_CONFIG "${CMAKE_CURRENT_SOURCE_DIR}/default_config.json")
    endif()
    if(NOT ARG_VARIANTS)
        set(ARG_VARIANTS "standard")
    endif()

    # Build command line for the generator script.
    set(CMD ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT}
        --output-dir ${ARG_OUTPUT_DIR}
        --datatype ${ARG_DATATYPE}
        --layout ${ARG_LAYOUT}
        --gpu-target ${ARG_GPU_TARGET}
        --config ${ARG_CONFIG}
        --variants ${ARG_VARIANTS}
    )

    # Parallel generation is opt-in; default is --no-parallel.
    if(NOT ARG_PARALLEL)
        list(APPEND CMD --no-parallel)
    endif()

    # Execute at configure time (blocking).
    execute_process(
        COMMAND ${CMD}
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        RESULT_VARIABLE RESULT
        OUTPUT_VARIABLE OUTPUT
        ERROR_VARIABLE ERROR
    )

    if(NOT RESULT EQUAL 0)
        message(FATAL_ERROR "Failed to generate GEMM kernels:\n${ERROR}")
    else()
        message(STATUS "Generated GEMM kernels: ${OUTPUT}")
    endif()
endfunction()
|
||||
|
||||
# Example usage documentation
|
||||
message(STATUS "CK Tile GEMM Code Generator configured")
|
||||
message(STATUS " Script: ${CODEGEN_SCRIPT}")
|
||||
message(STATUS " Config: ${CODEGEN_CONFIG}")
|
||||
message(STATUS " Output: ${CODEGEN_OUTPUT_DIR}")
|
||||
message(STATUS "")
|
||||
message(STATUS "To generate kernels:")
|
||||
message(STATUS " cmake --build . --target generate_tile_gemm_kernels")
|
||||
message(STATUS "")
|
||||
message(STATUS "Or use CMake function:")
|
||||
message(STATUS " ck_tile_generate_gemm_kernels(")
|
||||
message(STATUS " OUTPUT_DIR ./generated")
|
||||
message(STATUS " DATATYPE fp16")
|
||||
message(STATUS " LAYOUT rcr")
|
||||
message(STATUS " VARIANTS standard preshuffle multi_d")
|
||||
message(STATUS " PARALLEL")
|
||||
message(STATUS " )")
|
||||
123
dispatcher/codegen/README.md
Normal file
123
dispatcher/codegen/README.md
Normal file
@@ -0,0 +1,123 @@
|
||||
# CK Tile GEMM Unified Code Generator
|
||||
|
||||
Single source of truth for all GEMM kernel generation.
|
||||
|
||||
> **See also:** [Main Dispatcher README](../README.md) for installation and core concepts.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
cd dispatcher/codegen
|
||||
|
||||
# Generate standard FP16 kernels
|
||||
python3 unified_gemm_codegen.py \
|
||||
--output-dir ../build/generated_kernels \
|
||||
--datatype fp16 \
|
||||
--layout rcr \
|
||||
--variants standard
|
||||
|
||||
# Generate all variants
|
||||
python3 unified_gemm_codegen.py \
|
||||
--output-dir ../build/generated_kernels \
|
||||
--variants standard preshuffle multi_d
|
||||
```
|
||||
|
||||
## Using from Python
|
||||
|
||||
```python
|
||||
from ctypes_utils import CodegenRunner, KernelConfig
|
||||
|
||||
# Generate from specific config
|
||||
config = KernelConfig(tile_m=256, tile_n=256, tile_k=64)
|
||||
codegen = CodegenRunner()
|
||||
result = codegen.generate_from_config(config)
|
||||
|
||||
# Generate variant
|
||||
result = codegen.generate("preshuffle")
|
||||
|
||||
# Generate all
|
||||
results = codegen.generate_all()
|
||||
```
|
||||
|
||||
## Command Line Options
|
||||
|
||||
| Option | Values | Description |
|
||||
|--------|--------|-------------|
|
||||
| `--output-dir` | path | Output directory |
|
||||
| `--datatype` | `fp16`, `bf16`, `fp32`, `int8` | Data type |
|
||||
| `--layout` | `rcr`, `rrr`, `crr`, `ccr` | Matrix layouts |
|
||||
| `--gpu-target` | `gfx942`, `gfx90a`, `gfx950` | Target GPU |
|
||||
| `--variants` | `standard`, `preshuffle`, `multi_d` | Kernel variants |
|
||||
| `--preselected` | `fp16_rcr_essential`, etc. | Predefined kernel set |
|
||||
|
||||
### Layout Notation
|
||||
|
||||
- `R` = Row-major, `C` = Column-major
|
||||
- Order: A, B, C (e.g., `rcr` = A row, B col, C row)
|
||||
|
||||
## Variants
|
||||
|
||||
### Standard
|
||||
Basic GEMM: `C = A × B`
|
||||
|
||||
### PreShuffle
|
||||
Optimized weight access with LDS pre-shuffling. Best for large matrices.
|
||||
|
||||
### Multi-D
|
||||
Element-wise fusion: `C = op(A × B + D0 + D1 + ...)`
|
||||
|
||||
Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh`
|
||||
|
||||
## Output Structure
|
||||
|
||||
```
|
||||
generated_kernels/
|
||||
├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp
|
||||
├── gemm_fp16_rcr_compv4_..._preshuffle.hpp
|
||||
├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp
|
||||
└── ...
|
||||
```
|
||||
|
||||
## Configuration Files
|
||||
|
||||
### arch_specs.json
|
||||
|
||||
GPU architecture specifications (single source of truth):
|
||||
|
||||
```json
|
||||
{
|
||||
"architectures": {
|
||||
"gfx942": {
|
||||
"family": "cdna3",
|
||||
"warp_size": 64,
|
||||
"warp_configs": [[2, 2, 1], [4, 4, 1]],
|
||||
...
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### preselected_kernels.py
|
||||
|
||||
Curated kernel sets for common use cases.
|
||||
|
||||
## Adding New GPU Support
|
||||
|
||||
See [ADDING_NEW_GPU.md](ADDING_NEW_GPU.md) for complete guide.
|
||||
|
||||
Quick steps:
|
||||
1. Edit `arch_specs.json`
|
||||
2. Run `python generate_arch_specs.py`
|
||||
3. Rebuild
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| "Arguments not supported" | Check tile config validity |
|
||||
| Missing element-wise op | Check `elementwise_ops.hpp` |
|
||||
| Compilation errors | Verify C++17, include paths |
|
||||
|
||||
---
|
||||
|
||||
> **More info:** See [../README.md](../README.md) for full documentation.
|
||||
1012
dispatcher/codegen/arch_filter.py
Normal file
1012
dispatcher/codegen/arch_filter.py
Normal file
File diff suppressed because it is too large
Load Diff
270
dispatcher/codegen/arch_specs.json
Normal file
270
dispatcher/codegen/arch_specs.json
Normal file
@@ -0,0 +1,270 @@
|
||||
{
|
||||
"_comment": "Single source of truth for GPU architecture specifications. Edit this file to add new GPU support.",
|
||||
"_version": "1.2.0",
|
||||
"_instructions": "See ADDING_NEW_GPU.md for instructions on adding new GPU support.",
|
||||
"_supported_arch_note": "CK Tile supports: GFX9 (gfx908, gfx90a, gfx942, gfx950), GFX10.3 (gfx103x), GFX11 (gfx110x, gfx115x), GFX12 (gfx120x)",
|
||||
|
||||
"architectures": {
|
||||
"gfx908": {
|
||||
"family": "cdna1",
|
||||
"target_family": "gfx9",
|
||||
"architecture": "cdna",
|
||||
"description": "AMD Instinct MI100",
|
||||
"warp_size": 64,
|
||||
"lds_capacity_kb": 64,
|
||||
"warp_configs": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1]
|
||||
],
|
||||
"warp_tile_combos": {
|
||||
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]]
|
||||
}
|
||||
},
|
||||
|
||||
"gfx90a": {
|
||||
"family": "cdna2",
|
||||
"target_family": "gfx9",
|
||||
"architecture": "cdna",
|
||||
"description": "AMD Instinct MI200 series",
|
||||
"warp_size": 64,
|
||||
"lds_capacity_kb": 64,
|
||||
"warp_configs": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1]
|
||||
],
|
||||
"warp_tile_combos": {
|
||||
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]],
|
||||
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]]
|
||||
}
|
||||
},
|
||||
|
||||
"gfx942": {
|
||||
"family": "cdna3",
|
||||
"target_family": "gfx9",
|
||||
"architecture": "cdna",
|
||||
"description": "AMD Instinct MI300 series",
|
||||
"warp_size": 64,
|
||||
"lds_capacity_kb": 64,
|
||||
"warp_configs": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1]
|
||||
],
|
||||
"warp_tile_combos": {
|
||||
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]],
|
||||
"bf8_fp8_fp32": [[32, 32, 16]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]]
|
||||
}
|
||||
},
|
||||
|
||||
"gfx950": {
|
||||
"family": "cdna4",
|
||||
"target_family": "gfx9",
|
||||
"architecture": "cdna",
|
||||
"description": "AMD Instinct MI350 series",
|
||||
"warp_size": 64,
|
||||
"lds_capacity_kb": 160,
|
||||
"warp_configs": [
|
||||
[1, 4, 1],
|
||||
[2, 2, 1],
|
||||
[4, 1, 1]
|
||||
],
|
||||
"warp_tile_combos": {
|
||||
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
"fp8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 128], [32, 32, 64]],
|
||||
"bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
|
||||
"pk_fp4_pk_fp4_fp32": [[16, 16, 128]]
|
||||
}
|
||||
},
|
||||
|
||||
"gfx1100": {
|
||||
"family": "rdna3",
|
||||
"target_family": "gfx11",
|
||||
"architecture": "rdna",
|
||||
"description": "AMD Radeon RX 7900 series (RDNA3)",
|
||||
"warp_size": 32,
|
||||
"lds_capacity_kb": 64,
|
||||
"warp_configs": [
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[4, 2, 1]
|
||||
],
|
||||
"warp_tile_combos": {
|
||||
"fp16_fp16_fp32": [[16, 16, 16]],
|
||||
"bf16_bf16_fp32": [[16, 16, 16]],
|
||||
"int8_int8_int32": [[16, 16, 16]]
|
||||
}
|
||||
},
|
||||
|
||||
"gfx1200": {
|
||||
"family": "rdna4",
|
||||
"target_family": "gfx12",
|
||||
"architecture": "rdna",
|
||||
"description": "AMD Radeon RX 9000 series (RDNA4)",
|
||||
"warp_size": 32,
|
||||
"lds_capacity_kb": 64,
|
||||
"warp_configs": [
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[4, 2, 1]
|
||||
],
|
||||
"warp_tile_combos": {
|
||||
"fp16_fp16_fp32": [[16, 16, 16]],
|
||||
"bf16_bf16_fp32": [[16, 16, 16]],
|
||||
"fp8_fp8_fp32": [[16, 16, 16]],
|
||||
"bf8_bf8_fp32": [[16, 16, 16]],
|
||||
"fp8_bf8_fp32": [[16, 16, 16]],
|
||||
"bf8_fp8_fp32": [[16, 16, 16]],
|
||||
"int8_int8_int32": [[16, 16, 16]]
|
||||
}
|
||||
},
|
||||
|
||||
"gfx1201": {
|
||||
"family": "rdna4",
|
||||
"target_family": "gfx12",
|
||||
"architecture": "rdna",
|
||||
"description": "AMD Radeon RX 9000 series (RDNA4)",
|
||||
"warp_size": 32,
|
||||
"lds_capacity_kb": 64,
|
||||
"warp_configs": [
|
||||
[2, 4, 1],
|
||||
[1, 8, 1],
|
||||
[8, 1, 1],
|
||||
[4, 2, 1]
|
||||
],
|
||||
"warp_tile_combos": {
|
||||
"fp16_fp16_fp32": [[16, 16, 16]],
|
||||
"bf16_bf16_fp32": [[16, 16, 16]],
|
||||
"fp8_fp8_fp32": [[16, 16, 16]],
|
||||
"bf8_bf8_fp32": [[16, 16, 16]],
|
||||
"fp8_bf8_fp32": [[16, 16, 16]],
|
||||
"bf8_fp8_fp32": [[16, 16, 16]],
|
||||
"int8_int8_int32": [[16, 16, 16]]
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
"element_sizes": {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"fp32": 4,
|
||||
"fp64": 8,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int8": 1,
|
||||
"int4": 0.5,
|
||||
"pk_fp4": 0.5,
|
||||
"int32": 4
|
||||
},
|
||||
|
||||
"datatype_cpp_map": {
|
||||
"_comment": "Maps dtype string to CK Tile C++ type for code generation",
|
||||
"fp16": "ck_tile::half_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float",
|
||||
"fp64": "double",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"int8": "ck_tile::int8_t",
|
||||
"int4": "ck_tile::pk_int4_t",
|
||||
"pk_fp4": "ck_tile::pk_fp4_t",
|
||||
"int32": "ck_tile::int32_t"
|
||||
},
|
||||
|
||||
"dtype_combinations": {
|
||||
"_comment": "All valid (A, B) -> Acc combinations for GEMM from warp_gemm_dispatcher.hpp",
|
||||
"fp32_fp32": {"acc": "fp32", "notes": "Full precision"},
|
||||
"fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"},
|
||||
"bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"},
|
||||
"fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"},
|
||||
"fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"},
|
||||
"bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"},
|
||||
"bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"},
|
||||
"int8_int8": {"acc": "int32", "notes": "Integer GEMM"},
|
||||
"pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"}
|
||||
},
|
||||
|
||||
"layout_cpp_map": {
|
||||
"_comment": "Maps layout character to CK Tile C++ type",
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor"
|
||||
},
|
||||
|
||||
"pipeline_lds_limits": {
|
||||
"_comment": "LDS capacity limits in bytes for different pipeline types",
|
||||
"mem": 65536,
|
||||
"compv1": 65536,
|
||||
"compv2": 65536,
|
||||
"compv3": 65536,
|
||||
"compv4": 32768,
|
||||
"compv5": 65536,
|
||||
"preshufflev1": 32768,
|
||||
"preshufflev2": 32768,
|
||||
"default": 65536
|
||||
},
|
||||
|
||||
"unsupported_trait_combos": {
|
||||
"_comment": "Only 'mem' pipeline supports interwave scheduler. All compute pipelines only support intrawave.",
|
||||
"combinations": [
|
||||
["compv3", "cshuffle", "interwave"],
|
||||
["compv3", "default", "interwave"],
|
||||
["compv4", "cshuffle", "interwave"],
|
||||
["compv4", "default", "interwave"],
|
||||
["compv5", "cshuffle", "interwave"],
|
||||
["compv5", "default", "interwave"],
|
||||
["compv6", "cshuffle", "interwave"],
|
||||
["compv6", "default", "interwave"],
|
||||
["comp_async", "cshuffle", "interwave"],
|
||||
["comp_async", "default", "interwave"]
|
||||
]
|
||||
},
|
||||
|
||||
"preshuffle_warp_tile_combos": {
|
||||
"_comment": "Preshuffle-specific warp tile combinations (subset of standard GEMM, no [4, 64, 16])",
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]]
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]]
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
||||
}
|
||||
},
|
||||
|
||||
"preshuffle_pipelines": {
|
||||
"_comment": "Pipelines supported for preshuffle GEMM variant",
|
||||
"supported": ["preshufflev2"]
|
||||
}
|
||||
}
|
||||
358
dispatcher/codegen/arch_specs_generated.py
Normal file
358
dispatcher/codegen/arch_specs_generated.py
Normal file
@@ -0,0 +1,358 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
|
||||
|
||||
Generated from: arch_specs.json
|
||||
Generated at: 2026-01-05T19:34:01.224422
|
||||
|
||||
To update this file:
|
||||
1. Edit arch_specs.json
|
||||
2. Run: python generate_arch_specs.py
|
||||
|
||||
This module provides architecture-specific configurations for kernel filtering.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
# =============================================================================
|
||||
# Architecture Data (Generated from arch_specs.json)
|
||||
# =============================================================================
|
||||
|
||||
# GPU architecture to family mapping
|
||||
ARCH_FAMILY_MAP: Dict[str, str] = {
|
||||
"gfx908": "cdna1",
|
||||
"gfx90a": "cdna2",
|
||||
"gfx942": "cdna3",
|
||||
"gfx950": "cdna4",
|
||||
"gfx1100": "rdna3",
|
||||
"gfx1200": "rdna4",
|
||||
"gfx1201": "rdna4",
|
||||
}
|
||||
|
||||
# Element size in bytes for each data type
|
||||
ELEMENT_SIZE_MAP: Dict[str, float] = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"fp32": 4,
|
||||
"fp64": 8,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int8": 1,
|
||||
"int4": 0.5,
|
||||
"pk_fp4": 0.5,
|
||||
"int32": 4,
|
||||
}
|
||||
|
||||
# Supported warp configurations per architecture [warp_m, warp_n, warp_k]
|
||||
WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {
|
||||
"gfx908": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
"gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
"gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
"gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
|
||||
"gfx1100": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
|
||||
"gfx1200": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
|
||||
"gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
|
||||
}
|
||||
|
||||
# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...]
|
||||
WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {
|
||||
"gfx908": {
|
||||
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
|
||||
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
|
||||
},
|
||||
"gfx90a": {
|
||||
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
|
||||
"fp16_fp16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]],
|
||||
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
|
||||
"fp16_fp16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]],
|
||||
"bf8_fp8_fp32": [[32, 32, 16]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
|
||||
},
|
||||
"gfx950": {
|
||||
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
|
||||
"fp16_fp16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp32": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"fp8_bf8_fp32": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]],
|
||||
"bf8_bf8_fp32": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
|
||||
"pk_fp4_pk_fp4_fp32": [[16, 16, 128]],
|
||||
},
|
||||
"gfx1100": {
|
||||
"fp16_fp16_fp32": [[16, 16, 16]],
|
||||
"bf16_bf16_fp32": [[16, 16, 16]],
|
||||
"int8_int8_int32": [[16, 16, 16]],
|
||||
},
|
||||
"gfx1200": {
|
||||
"fp16_fp16_fp32": [[16, 16, 16]],
|
||||
"bf16_bf16_fp32": [[16, 16, 16]],
|
||||
"fp8_fp8_fp32": [[16, 16, 16]],
|
||||
"bf8_bf8_fp32": [[16, 16, 16]],
|
||||
"fp8_bf8_fp32": [[16, 16, 16]],
|
||||
"bf8_fp8_fp32": [[16, 16, 16]],
|
||||
"int8_int8_int32": [[16, 16, 16]],
|
||||
},
|
||||
"gfx1201": {
|
||||
"fp16_fp16_fp32": [[16, 16, 16]],
|
||||
"bf16_bf16_fp32": [[16, 16, 16]],
|
||||
"fp8_fp8_fp32": [[16, 16, 16]],
|
||||
"bf8_bf8_fp32": [[16, 16, 16]],
|
||||
"fp8_bf8_fp32": [[16, 16, 16]],
|
||||
"bf8_fp8_fp32": [[16, 16, 16]],
|
||||
"int8_int8_int32": [[16, 16, 16]],
|
||||
},
|
||||
}
|
||||
|
||||
# Preshuffle-specific warp tile combinations (subset of standard GEMM)
|
||||
PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_fp32": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp32": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_bf8_fp32": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Preshuffle-supported pipelines
|
||||
PRESHUFFLE_PIPELINES: List[str] = ["preshufflev2"]
|
||||
|
||||
# LDS capacity limits per pipeline type (in bytes)
|
||||
LDS_CAPACITY_LIMITS: Dict[str, int] = {
|
||||
"mem": 65536,
|
||||
"compv1": 65536,
|
||||
"compv2": 65536,
|
||||
"compv3": 65536,
|
||||
"compv4": 32768,
|
||||
"compv5": 65536,
|
||||
"preshufflev1": 32768,
|
||||
"preshufflev2": 32768,
|
||||
"default": 65536,
|
||||
}
|
||||
|
||||
# Unsupported trait combinations: (pipeline, epilogue, scheduler)
|
||||
TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
("compv5", "cshuffle", "interwave"),
|
||||
("compv5", "default", "interwave"),
|
||||
("compv6", "cshuffle", "interwave"),
|
||||
("compv6", "default", "interwave"),
|
||||
("comp_async", "cshuffle", "interwave"),
|
||||
("comp_async", "default", "interwave"),
|
||||
}
|
||||
|
||||
# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes
|
||||
DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = {
|
||||
"fp32_fp32": {"acc": "fp32", "notes": "Full precision"},
|
||||
"fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"},
|
||||
"bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"},
|
||||
"fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"},
|
||||
"fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"},
|
||||
"bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"},
|
||||
"bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"},
|
||||
"int8_int8": {"acc": "int32", "notes": "Integer GEMM"},
|
||||
"pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"},
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_supported_archs() -> List[str]:
    """Return every GPU architecture name known to this module."""
    # The family map's keys double as the canonical list of supported archs.
    return [arch for arch in ARCH_FAMILY_MAP]
|
||||
|
||||
|
||||
def get_arch_family(gpu_arch: str) -> str:
    """Return the GPU family string (e.g. "cdna3") for *gpu_arch*, or "unknown"."""
    # Lookup is case-insensitive; unknown architectures map to "unknown".
    arch_key = gpu_arch.lower()
    return ARCH_FAMILY_MAP.get(arch_key, "unknown")
|
||||
|
||||
|
||||
def get_element_size(dtype: str) -> float:
    """Return the element size in bytes for *dtype*."""
    # Unknown dtypes fall back to 2.0 bytes (fp16-sized), matching prior behavior.
    try:
        return ELEMENT_SIZE_MAP[dtype.lower()]
    except KeyError:
        return 2.0
|
||||
|
||||
|
||||
def get_warp_configs(gpu_arch: str) -> List[List[int]]:
    """Return the [warp_m, warp_n, warp_k] configurations supported by *gpu_arch*."""
    # An unrecognized architecture simply has no valid warp configurations.
    arch_key = gpu_arch.lower()
    return WARP_SUPPORTED_COMBINATIONS.get(arch_key, [])
|
||||
|
||||
|
||||
def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]:
    """Return supported [m, n, k] warp tiles for *gpu_arch* and *dtype_key*."""
    # Two-level lookup: architecture first, then the "<a>_<b>_<acc>" dtype key.
    per_dtype = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower())
    if per_dtype is None:
        return []
    return per_dtype.get(dtype_key.lower(), [])
|
||||
|
||||
|
||||
def get_lds_limit(pipeline: str) -> int:
    """Return the LDS byte budget for *pipeline*, falling back to the "default" entry."""
    limits = LDS_CAPACITY_LIMITS
    return limits.get(pipeline.lower(), limits["default"])
|
||||
|
||||
|
||||
def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool:
    """Return True when (pipeline, epilogue, scheduler) is a known-unsupported combo."""
    # Normalize to lowercase so the membership test matches the generated set.
    combo = (pipeline.lower(), epilogue.lower(), scheduler.lower())
    return combo in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
|
||||
|
||||
def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]:
    """Return {"acc": ..., "notes": ...} for the (A, B) dtype pair."""
    combo_key = "{}_{}".format(dtype_a.lower(), dtype_b.lower())
    # Unrecognized pairs get an fp32 accumulator placeholder.
    fallback = {"acc": "fp32", "notes": "unknown"}
    return DTYPE_COMBINATIONS.get(combo_key, fallback)
|
||||
|
||||
|
||||
def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool:
    """Return True when the (A, B) dtype pair is a recognized GEMM combination."""
    return f"{dtype_a.lower()}_{dtype_b.lower()}" in DTYPE_COMBINATIONS
|
||||
|
||||
|
||||
def get_valid_dtype_combos() -> List[str]:
    """Return every recognized "<a>_<b>" dtype-combination key."""
    return [combo for combo in DTYPE_COMBINATIONS]
|
||||
27
dispatcher/codegen/default_config.json
Normal file
27
dispatcher/codegen/default_config.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": [128, 256],
|
||||
"tile_n": [128, 256],
|
||||
"tile_k": [32, 64],
|
||||
"warp_m": [2, 4],
|
||||
"warp_n": [2, 4],
|
||||
"warp_k": [1],
|
||||
"warp_tile_m": [16, 32],
|
||||
"warp_tile_n": [16, 32],
|
||||
"warp_tile_k": [16]
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": ["compv4"],
|
||||
"epilogue": ["cshuffle"],
|
||||
"scheduler": ["intrawave"],
|
||||
"pad_m": [false],
|
||||
"pad_n": [false],
|
||||
"pad_k": [false],
|
||||
"persistent": [false, true]
|
||||
},
|
||||
"multi_d_config": {
|
||||
"elementwise_ops": ["MultiDAdd", "Relu", "Gelu"],
|
||||
"num_d_tensors": [1, 2]
|
||||
}
|
||||
}
|
||||
|
||||
452
dispatcher/codegen/generate_arch_specs.py
Normal file
452
dispatcher/codegen/generate_arch_specs.py
Normal file
@@ -0,0 +1,452 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Architecture Specs Generator
|
||||
|
||||
Generates both Python and C++ code from a single JSON source of truth.
|
||||
This ensures consistency between Python codegen and C++ runtime filtering.
|
||||
|
||||
Usage:
|
||||
python generate_arch_specs.py [--json arch_specs.json] [--output-dir .]
|
||||
|
||||
# Regenerate after editing arch_specs.json:
|
||||
python generate_arch_specs.py
|
||||
|
||||
Output:
|
||||
- arch_specs_generated.py (Python module with arch data)
|
||||
- arch_specs_generated.hpp (C++ header with arch data)
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
|
||||
SCRIPT_DIR = Path(__file__).parent
|
||||
|
||||
|
||||
def load_arch_specs(json_path: Path) -> Dict[str, Any]:
    """Read and parse the architecture-spec JSON document at *json_path*."""
    raw_text = Path(json_path).read_text()
    return json.loads(raw_text)
|
||||
|
||||
|
||||
def generate_python_module(specs: Dict[str, Any], output_path: Path):
    """Generate the Python arch-specs module (arch_specs_generated.py) from *specs*.

    Renders the architecture data loaded from arch_specs.json into Python
    source text (data constants plus lookup helpers) and writes it to
    *output_path*. The emitted file is self-contained and must stay in sync
    with the C++ header emitted by generate_cpp_header().

    Args:
        specs: Parsed arch_specs.json contents. Required keys:
            "architectures", "element_sizes", "pipeline_lds_limits",
            "unsupported_trait_combos"; optional: "dtype_combinations",
            "preshuffle_warp_tile_combos", "preshuffle_pipelines".
        output_path: Destination file path; overwritten unconditionally.
    """

    # Timestamp is embedded in the generated header for traceability.
    timestamp = datetime.now().isoformat()

    # Extract data sections (KeyError here means arch_specs.json is malformed).
    archs = specs["architectures"]
    element_sizes = specs["element_sizes"]
    pipeline_limits = specs["pipeline_lds_limits"]
    unsupported = specs["unsupported_trait_combos"]["combinations"]

    # Build warp configs dict literal, one arch per line.
    warp_configs_str = "{\n"
    for arch, data in archs.items():
        warp_configs_str += f'    "{arch}": {data["warp_configs"]},\n'
    warp_configs_str += "}"

    # Build warp tile combos dict literal: arch -> dtype_key -> combos.
    warp_tile_str = "{\n"
    for arch, data in archs.items():
        warp_tile_str += f'    "{arch}": {{\n'
        for dtype, combos in data["warp_tile_combos"].items():
            warp_tile_str += f'        "{dtype}": {combos},\n'
        warp_tile_str += "    },\n"
    warp_tile_str += "}"

    # Build arch -> family map literal.
    arch_family_str = "{\n"
    for arch, data in archs.items():
        arch_family_str += f'    "{arch}": "{data["family"]}",\n'
    arch_family_str += "}"

    # Build unsupported (pipeline, epilogue, scheduler) combos as a set literal.
    unsupported_str = "{\n"
    for combo in unsupported:
        unsupported_str += f'    ("{combo[0]}", "{combo[1]}", "{combo[2]}"),\n'
    unsupported_str += "}"

    # Pipeline LDS limits: strip "_comment"-style metadata keys before embedding.
    pipeline_limits_clean = {
        k: v for k, v in pipeline_limits.items() if not k.startswith("_")
    }

    # Build dtype combinations dict literal (skipping "_"-prefixed metadata keys).
    dtype_combos = specs.get("dtype_combinations", {})
    dtype_combos_str = "{\n"
    for key, info in dtype_combos.items():
        if not key.startswith("_"):
            dtype_combos_str += f'    "{key}": {{"acc": "{info["acc"]}", "notes": "{info["notes"]}"}},\n'
    dtype_combos_str += "}"

    # Build preshuffle warp tile combos dict literal (operator-specific subset).
    preshuffle_combos = specs.get("preshuffle_warp_tile_combos", {})
    preshuffle_warp_tile_str = "{\n"
    for arch, dtype_combos_dict in preshuffle_combos.items():
        if not arch.startswith("_"):
            preshuffle_warp_tile_str += f'    "{arch}": {{\n'
            for dtype, combos in dtype_combos_dict.items():
                preshuffle_warp_tile_str += f'        "{dtype}": {combos},\n'
            preshuffle_warp_tile_str += "    },\n"
    preshuffle_warp_tile_str += "}"

    # Preshuffle pipelines list; defaults to preshufflev2 when unspecified.
    preshuffle_pipelines = specs.get("preshuffle_pipelines", {}).get(
        "supported", ["preshufflev2"]
    )
    preshuffle_pipelines_str = str(preshuffle_pipelines)

    # NOTE: doubled braces ({{ }}) below are f-string escapes that emit literal
    # braces into the generated module. The copyright line matches the header
    # carried by the checked-in arch_specs_generated.py so regeneration does
    # not strip it.
    content = f'''# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

"""
AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!

Generated from: arch_specs.json
Generated at: {timestamp}

To update this file:
1. Edit arch_specs.json
2. Run: python generate_arch_specs.py

This module provides architecture-specific configurations for kernel filtering.
"""

from typing import Dict, List, Set, Tuple

# =============================================================================
# Architecture Data (Generated from arch_specs.json)
# =============================================================================

# GPU architecture to family mapping
ARCH_FAMILY_MAP: Dict[str, str] = {arch_family_str}

# Element size in bytes for each data type
ELEMENT_SIZE_MAP: Dict[str, float] = {element_sizes}

# Supported warp configurations per architecture [warp_m, warp_n, warp_k]
WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {warp_configs_str}

# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...]
WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {warp_tile_str}

# Preshuffle-specific warp tile combinations (subset of standard GEMM)
PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {preshuffle_warp_tile_str}

# Preshuffle-supported pipelines
PRESHUFFLE_PIPELINES: List[str] = {preshuffle_pipelines_str}

# LDS capacity limits per pipeline type (in bytes)
LDS_CAPACITY_LIMITS: Dict[str, int] = {pipeline_limits_clean}

# Unsupported trait combinations: (pipeline, epilogue, scheduler)
TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = {unsupported_str}

# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes
DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = {dtype_combos_str}

# =============================================================================
# Helper Functions
# =============================================================================


def get_supported_archs() -> List[str]:
    """Get list of all supported GPU architectures."""
    return list(ARCH_FAMILY_MAP.keys())


def get_arch_family(gpu_arch: str) -> str:
    """Get the GPU family for an architecture."""
    return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown")


def get_element_size(dtype: str) -> float:
    """Get element size in bytes for a data type."""
    return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0)


def get_warp_configs(gpu_arch: str) -> List[List[int]]:
    """Get supported warp configurations for an architecture."""
    return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), [])


def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]:
    """Get supported warp tile combinations for arch and data types."""
    gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {{}})
    return gpu_combos.get(dtype_key.lower(), [])


def get_lds_limit(pipeline: str) -> int:
    """Get LDS capacity limit for a pipeline type."""
    return LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"])


def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool:
    """Check if a trait combination is unsupported."""
    return (pipeline.lower(), epilogue.lower(), scheduler.lower()) in TRAIT_UNSUPPORTED_COMBINATIONS


def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]:
    """Get accumulator type and notes for a dtype combination."""
    key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}"
    return DTYPE_COMBINATIONS.get(key, {{"acc": "fp32", "notes": "unknown"}})


def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool:
    """Check if a dtype combination is valid."""
    key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}"
    return key in DTYPE_COMBINATIONS


def get_valid_dtype_combos() -> List[str]:
    """Get list of all valid dtype combinations."""
    return list(DTYPE_COMBINATIONS.keys())
'''

    output_path.write_text(content)
    print(f"Generated: {output_path}")
|
||||
|
||||
|
||||
def generate_cpp_header(specs: Dict[str, Any], output_path: Path):
    """Generate the C++ header (arch_specs_generated.hpp) from arch specs.

    Args:
        specs: Parsed contents of arch_specs.json.  Must contain
            "architectures", "element_sizes", "pipeline_lds_limits" and
            "unsupported_trait_combos" -> "combinations".
        output_path: Destination header file; overwritten unconditionally.

    Raises:
        KeyError: if a required section is missing from ``specs``.
    """

    timestamp = datetime.now().isoformat()

    # Extract data
    archs = specs["architectures"]
    element_sizes = specs["element_sizes"]
    pipeline_limits = specs["pipeline_lds_limits"]
    # Fail fast when arch_specs.json drifts.  (Previously this was a bare,
    # no-op subscript expression; the generated is_trait_unsupported() below
    # hardcodes the combinations rather than reading them from the JSON.)
    if "combinations" not in specs["unsupported_trait_combos"]:
        raise KeyError(
            "arch_specs.json: missing unsupported_trait_combos.combinations"
        )

    # Build arch enum entries and the string<->enum conversion case lines.
    arch_enums = []
    arch_to_string_cases = []
    string_to_arch_cases = []

    for arch, data in archs.items():
        # e.g. "gfx942" -> enum name "GFX_942"
        enum_name = arch.upper().replace("GFX", "GFX_")
        arch_enums.append(f"    {enum_name}, // {data['description']}")
        arch_to_string_cases.append(
            f'        case GpuArch::{enum_name}: return "{arch}";'
        )
        string_to_arch_cases.append(
            f'    if (arch_str == "{arch}") return GpuArch::{enum_name};'
        )

    # Build warp configs switch
    warp_config_cases = []
    for arch, data in archs.items():
        enum_name = arch.upper().replace("GFX", "GFX_")
        configs = ", ".join(
            [f"{{{c[0]}, {c[1]}, {c[2]}}}" for c in data["warp_configs"]]
        )
        warp_config_cases.append(
            f"        case GpuArch::{enum_name}: return {{{configs}}};"
        )

    # Build element size switch
    # Include all data types defined in kernel_key.hpp DataType enum
    elem_size_cases = []
    dtype_enum_map = {
        "fp16": "FP16",
        "bf16": "BF16",
        "fp32": "FP32",
        "fp64": "FP64",
        "fp8": "FP8",
        "bf8": "BF8",
        "int8": "INT8",
        "int4": "INT4",
        "int32": "INT32",
    }
    for dtype, size in element_sizes.items():
        # dtypes without a DataType enum counterpart are skipped silently.
        if dtype in dtype_enum_map:
            elem_size_cases.append(
                f"        case DataType::{dtype_enum_map[dtype]}: return {float(size)}f;"
            )

    # Build LDS limits
    lds_limit_cases = []
    pipeline_enum_map = {
        "mem": "Mem",
        "compv1": "CompV1",
        "compv2": "CompV2",
        "compv3": "CompV3",
        "compv4": "CompV4",
        "compv5": "CompV5",
        "preshufflev1": "PreShuffleV1",
        "preshufflev2": "PreShuffleV2",
    }
    default_lds = pipeline_limits.get("default", 65536)
    for pipeline, limit in pipeline_limits.items():
        if pipeline in pipeline_enum_map:
            lds_limit_cases.append(
                f"    if (pipeline == Pipeline::{pipeline_enum_map[pipeline]}) return {limit};"
            )

    # Doubled braces ({{ }}) are f-string escapes for literal C++ braces.
    content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

/**
 * AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
 *
 * Generated from: arch_specs.json
 * Generated at: {timestamp}
 *
 * To update this file:
 * 1. Edit arch_specs.json
 * 2. Run: python generate_arch_specs.py
 */

#pragma once

#include "ck_tile/dispatcher/kernel_key.hpp"
#include <array>
#include <string>
#include <vector>
#include <cstdint>

namespace ck_tile {{
namespace dispatcher {{
namespace arch_specs {{

// =============================================================================
// GPU Architecture Enum (Generated)
// =============================================================================

enum class GpuArch : std::uint8_t {{
{chr(10).join(arch_enums)}
    UNKNOWN
}};

// =============================================================================
// String Conversion Functions (Generated)
// =============================================================================

inline std::string arch_to_string(GpuArch arch) {{
    switch (arch) {{
{chr(10).join(arch_to_string_cases)}
        default: return "unknown";
    }}
}}

inline GpuArch string_to_arch(const std::string& arch_str) {{
{chr(10).join(string_to_arch_cases)}
    return GpuArch::UNKNOWN;
}}

// =============================================================================
// Element Size (Generated)
// =============================================================================

inline float element_size(DataType dtype) {{
    switch (dtype) {{
{chr(10).join(elem_size_cases)}
        default: return 2.0f;
    }}
}}

// =============================================================================
// Warp Configurations (Generated)
// =============================================================================

using WarpConfig = std::array<int, 3>;

inline std::vector<WarpConfig> get_supported_warp_configs(GpuArch arch) {{
    switch (arch) {{
{chr(10).join(warp_config_cases)}
        default: return {{}};
    }}
}}

// =============================================================================
// LDS Capacity Limits (Generated)
// =============================================================================

inline std::size_t get_lds_capacity(Pipeline pipeline) {{
{chr(10).join(lds_limit_cases)}
    return {default_lds}; // Default
}}

// =============================================================================
// Unsupported Trait Combinations (Generated)
// =============================================================================

inline bool is_trait_unsupported(Pipeline pipeline, [[maybe_unused]] Epilogue epilogue, Scheduler scheduler) {{
    // Generated from unsupported_trait_combos in arch_specs.json
    if (scheduler == Scheduler::Interwave) {{
        if (pipeline == Pipeline::CompV3 || pipeline == Pipeline::CompV4) {{
            return true;
        }}
    }}
    return false;
}}

}} // namespace arch_specs
}} // namespace dispatcher
}} // namespace ck_tile
"""

    output_path.write_text(content)
    print(f"Generated: {output_path}")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: load arch_specs.json and emit both generated files."""
    arg_parser = argparse.ArgumentParser(
        description="Generate Python and C++ code from arch_specs.json"
    )
    arg_parser.add_argument(
        "--json",
        type=Path,
        default=SCRIPT_DIR / "arch_specs.json",
        help="Path to arch_specs.json",
    )
    arg_parser.add_argument(
        "--output-dir",
        type=Path,
        default=SCRIPT_DIR,
        help="Output directory for generated files",
    )
    arg_parser.add_argument(
        "--cpp-output-dir",
        type=Path,
        default=None,
        help="Output directory for C++ header (defaults to dispatcher/include/...)",
    )
    opts = arg_parser.parse_args()

    # Read the shared spec once; both generators consume the same dict.
    print(f"Loading: {opts.json}")
    loaded_specs = load_arch_specs(opts.json)

    # Python side.
    python_target = opts.output_dir / "arch_specs_generated.py"
    generate_python_module(loaded_specs, python_target)

    # C++ side: honor an explicit directory, otherwise use the in-tree default.
    if opts.cpp_output_dir:
        header_target = opts.cpp_output_dir / "arch_specs_generated.hpp"
    else:
        header_target = (
            SCRIPT_DIR.parent
            / "include"
            / "ck_tile"
            / "dispatcher"
            / "arch_specs_generated.hpp"
        )
    header_target.parent.mkdir(parents=True, exist_ok=True)
    generate_cpp_header(loaded_specs, header_target)

    print("\nDone! To apply changes:")
    print(" 1. Python code will automatically use arch_specs_generated.py")
    print(" 2. C++ code includes arch_specs_generated.hpp")
|
||||
|
||||
|
||||
# Script entry point: regenerate arch_specs_generated.py and the C++ header.
if __name__ == "__main__":
    main()
|
||||
429
dispatcher/codegen/generate_dispatcher_registration.py
Normal file
429
dispatcher/codegen/generate_dispatcher_registration.py
Normal file
@@ -0,0 +1,429 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Generate dispatcher registration code for CK Tile kernels
|
||||
|
||||
This script generates C++ registration code that instantiates TileKernelInstance
|
||||
templates for each generated kernel, solving the "cannot instantiate from parsed headers" problem.
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class KernelConfig:
    """Kernel configuration for registration.

    Mirrors one generated kernel header / one manifest entry; consumed by the
    registration-code and manifest generators in this script.
    """

    # Identity: registration name and the generated header that defines it.
    name: str
    header_file: str
    # Block tile shape (per-workgroup tile along M/N/K).
    tile_m: int
    tile_n: int
    tile_k: int
    # Warps per block along M/N/K.
    warp_m: int
    warp_n: int
    warp_k: int
    # Per-warp tile shape along M/N/K.
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int
    # Threads per workgroup.
    block_size: int
    # Trait strings (e.g. "compv4" / "cshuffle" / "intrawave" — the defaults
    # used by the loaders below).
    pipeline: str
    epilogue: str
    scheduler: str
    # Padding flags for problem sizes not divisible by the tile.
    pad_m: bool
    pad_n: bool
    pad_k: bool
    # Kernel variants / options.
    persistent: bool
    double_buffer: bool
    transpose_c: bool
    # Data types (defaults: fp16 operands/output, fp32 accumulator).
    dtype_a: str = "fp16"
    dtype_b: str = "fp16"
    dtype_c: str = "fp16"
    dtype_acc: str = "fp32"
    # Memory layouts for A, B and C ("row"/"col").
    layout_a: str = "row"
    layout_b: str = "col"
    layout_c: str = "row"
|
||||
|
||||
|
||||
def generate_registration_header(kernels: List[KernelConfig], output_file: Path):
    """Generate registration header file.

    Emits a header that #includes every generated kernel header and defines
    inline ``register_all_kernels()`` overloads which register each kernel
    with a dispatcher Registry.  Only ``name`` and ``header_file`` of each
    KernelConfig are used here.

    NOTE(review): the emitted calls assume a type ``SelectedKernel_<name>``
    exists for every kernel (see generate_kernel_wrapper_header) — confirm the
    generated headers actually provide it.
    """

    content = """// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py

#pragma once

#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/backends/tile_backend.hpp"
#include "ck_tile/dispatcher/backends/kernel_registration.hpp"

// Include all generated kernel headers
"""

    # Add includes for all kernel headers
    for kernel in kernels:
        content += f'#include "{kernel.header_file}"\n'

    content += """

namespace ck_tile {
namespace dispatcher {
namespace generated {

/// Register all generated kernels with the dispatcher
inline void register_all_kernels(Registry& registry)
{
"""

    # Add registration calls for each kernel
    for kernel in kernels:
        # Extract the SelectedKernel type name from the header file
        # Assuming the header defines a type like: using SelectedKernel = ...
        kernel_type = f"SelectedKernel_{kernel.name}"

        content += f"""    // Register {kernel.name}
    register_tile_kernel<{kernel_type}>(registry, "{kernel.name}");
"""

    # Close the Registry overload, then add a zero-argument convenience
    # overload that targets the global singleton registry.
    content += """}

/// Register all generated kernels with the global registry
inline void register_all_kernels()
{
    auto& registry = Registry::instance();
    register_all_kernels(registry);
}

} // namespace generated
} // namespace dispatcher
} // namespace ck_tile
"""

    output_file.write_text(content)
    print(f"✓ Generated registration header: {output_file}")
|
||||
|
||||
|
||||
def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path):
    """Generate registration implementation file.

    Emits a .cpp that includes dispatcher_registration.hpp and explicitly
    instantiates ``backends::TileKernelInstance<SelectedKernel_<name>>`` for
    each kernel, so every template is compiled exactly once in this
    translation unit.
    """

    content = """// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py

#include "dispatcher_registration.hpp"

namespace ck_tile {
namespace dispatcher {
namespace generated {

// Explicit instantiations to reduce compile time
// These ensure the templates are instantiated once

"""

    # One explicit template instantiation per kernel.
    for kernel in kernels:
        kernel_type = f"SelectedKernel_{kernel.name}"
        content += f"template class backends::TileKernelInstance<{kernel_type}>;\n"

    content += """
} // namespace generated
} // namespace dispatcher
} // namespace ck_tile
"""

    output_file.write_text(content)
    print(f"✓ Generated registration implementation: {output_file}")
|
||||
|
||||
|
||||
def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path):
    """Generate a wrapper header that defines the SelectedKernel type alias.

    Writes ``<output_dir>/<name>_wrapper.hpp``; the written path is not
    returned.

    NOTE(review): the emitted alias right-hand side is a C++ comment
    placeholder, so the produced header does NOT compile as-is.  This looks
    like scaffolding awaiting the real kernel type name — confirm before
    wiring it into a build.
    """

    wrapper_file = output_dir / f"{kernel.name}_wrapper.hpp"

    content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py

#pragma once

#include "{kernel.header_file}"

namespace ck_tile {{
namespace dispatcher {{
namespace generated {{

// Type alias for dispatcher registration
// This allows the registration code to reference the kernel type
using SelectedKernel_{kernel.name} = /* Actual kernel type from generated header */;

}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile
"""

    wrapper_file.write_text(content)
|
||||
|
||||
|
||||
def load_kernel_manifest(manifest_file: Path) -> List[KernelConfig]:
    """Load kernel configurations from manifest file.

    Reads the JSON manifest's "kernels" array.  "name", "header_file" and the
    tile shape are required; every other field falls back to the generator
    defaults.

    Raises:
        KeyError: if a required field is missing from a manifest entry.
    """

    with open(manifest_file, "r") as f:
        data = json.load(f)

    kernels = []
    for kernel_data in data.get("kernels", []):
        kernel = KernelConfig(
            name=kernel_data["name"],
            header_file=kernel_data["header_file"],
            tile_m=kernel_data["tile_m"],
            tile_n=kernel_data["tile_n"],
            tile_k=kernel_data["tile_k"],
            warp_m=kernel_data.get("warp_m", 2),
            warp_n=kernel_data.get("warp_n", 2),
            warp_k=kernel_data.get("warp_k", 1),
            warp_tile_m=kernel_data.get("warp_tile_m", 32),
            warp_tile_n=kernel_data.get("warp_tile_n", 32),
            warp_tile_k=kernel_data.get("warp_tile_k", 16),
            block_size=kernel_data.get("block_size", 256),
            pipeline=kernel_data.get("pipeline", "compv4"),
            epilogue=kernel_data.get("epilogue", "cshuffle"),
            scheduler=kernel_data.get("scheduler", "intrawave"),
            pad_m=kernel_data.get("pad_m", False),
            pad_n=kernel_data.get("pad_n", False),
            pad_k=kernel_data.get("pad_k", False),
            persistent=kernel_data.get("persistent", False),
            double_buffer=kernel_data.get("double_buffer", True),
            transpose_c=kernel_data.get("transpose_c", False),
            dtype_a=kernel_data.get("dtype_a", "fp16"),
            dtype_b=kernel_data.get("dtype_b", "fp16"),
            dtype_c=kernel_data.get("dtype_c", "fp16"),
            dtype_acc=kernel_data.get("dtype_acc", "fp32"),
            # Fix: layouts were previously ignored even when present in the
            # manifest; honor them, keeping the dataclass defaults otherwise.
            layout_a=kernel_data.get("layout_a", "row"),
            layout_b=kernel_data.get("layout_b", "col"),
            layout_c=kernel_data.get("layout_c", "row"),
        )
        kernels.append(kernel)

    return kernels
|
||||
|
||||
|
||||
def scan_generated_headers(generated_dir: Path) -> List[KernelConfig]:
    """Scan generated headers and extract kernel configurations.

    Recursively walks ``generated_dir`` for ``*.hpp`` files, keeps those that
    declare a ``KERNEL_NAME`` constant, and reconstructs a KernelConfig from
    the ``constexpr`` tile/warp/block constants found in each header.
    Headers that fail to parse are reported and skipped (best-effort scan).
    """

    import re

    def _constexpr_int(content: str, name: str, default: int) -> int:
        # Matches e.g. `static constexpr ck_tile::index_t TileM = 256;`
        # (`static` optional; int / std::size_t / ck_tile::index_t accepted).
        # Replaces seven hand-copied regex blocks in the original.
        m = re.search(
            r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)\s+"
            + name
            + r"\s*=\s*(\d+)",
            content,
        )
        return int(m.group(1)) if m else default

    def _bool_flag(content: str, name: str) -> bool:
        # True only when the header literally sets `<name> = true`.
        return re.search(name + r"\s*=\s*true", content) is not None

    kernels = []

    for header_file in generated_dir.glob("**/*.hpp"):
        try:
            content = header_file.read_text()

            # Only headers that advertise a kernel name are kernel headers.
            name_match = re.search(
                r'constexpr const char\* KERNEL_NAME\s*=\s*"([^"]+)"', content
            )
            if not name_match:
                continue

            kernel = KernelConfig(
                name=name_match.group(1),
                header_file=str(header_file.relative_to(generated_dir.parent)),
                tile_m=_constexpr_int(content, "TileM", 256),
                tile_n=_constexpr_int(content, "TileN", 256),
                tile_k=_constexpr_int(content, "TileK", 32),
                warp_m=_constexpr_int(content, "WarpPerBlock_M", 2),
                warp_n=_constexpr_int(content, "WarpPerBlock_N", 2),
                warp_k=_constexpr_int(content, "WarpPerBlock_K", 1),
                warp_tile_m=_constexpr_int(content, "WarpTileM", 32),
                warp_tile_n=_constexpr_int(content, "WarpTileN", 32),
                warp_tile_k=_constexpr_int(content, "WarpTileK", 16),
                block_size=_constexpr_int(content, "BlockSize", 256),
                # Traits are not recorded in the headers; assume the generator
                # defaults (same hard-coded values as before).
                pipeline="compv4",
                epilogue="cshuffle",
                scheduler="intrawave",
                pad_m=_bool_flag(content, "kPadM"),
                pad_n=_bool_flag(content, "kPadN"),
                pad_k=_bool_flag(content, "kPadK"),
                persistent=_bool_flag(content, "UsePersistentKernel"),
                double_buffer=_bool_flag(content, "DoubleSmemBuffer"),
                transpose_c=_bool_flag(content, "TransposeC"),
            )

            kernels.append(kernel)

        except Exception as e:
            # Best-effort: a single malformed header must not abort the scan.
            print(f"Warning: Failed to parse {header_file}: {e}")
            continue

    return kernels
|
||||
|
||||
|
||||
def main():
    """CLI entry point for registration-code generation.

    Loads kernel configs either from a JSON manifest (--manifest) or by
    scanning generated headers (--scan), then writes the registration header,
    the registration .cpp, and a kernels_manifest.json summary into
    --output-dir.

    Returns:
        Process exit code: 0 on success, 1 on usage error.
    """
    parser = argparse.ArgumentParser(
        description="Generate dispatcher registration code"
    )
    parser.add_argument(
        "--generated-dir",
        type=str,
        required=True,
        help="Directory containing generated kernel headers",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory for registration code",
    )
    parser.add_argument(
        "--manifest", type=str, help="Optional manifest file with kernel configurations"
    )
    parser.add_argument(
        "--scan",
        action="store_true",
        help="Scan generated headers instead of using manifest",
    )

    args = parser.parse_args()

    generated_dir = Path(args.generated_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load kernel configurations.  --manifest takes precedence over --scan.
    if args.manifest:
        print(f"Loading kernels from manifest: {args.manifest}")
        kernels = load_kernel_manifest(Path(args.manifest))
    elif args.scan:
        print(f"Scanning generated headers in: {generated_dir}")
        kernels = scan_generated_headers(generated_dir)
    else:
        print("Error: Must specify either --manifest or --scan")
        return 1

    print(f"Found {len(kernels)} kernels")

    # Generate registration code
    registration_header = output_dir / "dispatcher_registration.hpp"
    registration_cpp = output_dir / "dispatcher_registration.cpp"

    generate_registration_header(kernels, registration_header)
    generate_registration_cpp(kernels, registration_cpp)

    # Generate manifest for Python (subset of fields only: identity, tile
    # shape, block size and persistence).
    manifest_output = output_dir / "kernels_manifest.json"
    manifest_data = {
        "kernels": [
            {
                "name": k.name,
                "header_file": k.header_file,
                "tile_m": k.tile_m,
                "tile_n": k.tile_n,
                "tile_k": k.tile_k,
                "block_size": k.block_size,
                "persistent": k.persistent,
            }
            for k in kernels
        ]
    }

    with open(manifest_output, "w") as f:
        json.dump(manifest_data, f, indent=2)

    print(f"✓ Generated manifest: {manifest_output}")
    print("\n✓ Registration code generation complete!")
    print(f" Total kernels: {len(kernels)}")
    print(" Output files:")
    print(f" - {registration_header}")
    print(f" - {registration_cpp}")
    print(f" - {manifest_output}")

    return 0
|
||||
|
||||
|
||||
# Script entry point: propagate main()'s return value as the exit code.
if __name__ == "__main__":
    import sys

    # Use sys.exit rather than the `exit()` builtin: `exit` is injected by
    # the `site` module and is absent under `python -S` or in frozen apps.
    sys.exit(main())
|
||||
430
dispatcher/codegen/generate_kernel_wrappers.py
Normal file
430
dispatcher/codegen/generate_kernel_wrappers.py
Normal file
@@ -0,0 +1,430 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Generate one .cpp wrapper file per kernel header for maximum parallel compilation.
|
||||
|
||||
Each kernel becomes its own translation unit, enabling:
|
||||
- Maximum parallelism with make -j$(nproc)
|
||||
- Per-kernel build progress (e.g., [5/128] Building kernel: gemm_fp16_128x128)
|
||||
- Incremental rebuilds (only changed kernels recompile)
|
||||
- Fine-grained build time analysis
|
||||
|
||||
Usage:
|
||||
python3 generate_kernel_wrappers.py --kernel-dir build/generated_kernels --output-dir build/kernel_wrappers
|
||||
|
||||
Output structure:
|
||||
build/kernel_wrappers/
|
||||
├── gemm_fp16_rcr_128x128x32.cpp
|
||||
├── gemm_fp16_rcr_256x256x64.cpp
|
||||
├── conv_fwd_fp16_2d_128x128.cpp
|
||||
└── ...
|
||||
|
||||
Each .cpp simply includes its corresponding .hpp and forces symbol emission.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
# C++ wrapper emitted for GEMM kernels.  Plain str.format() template:
# {kernel_name}/{kernel_hpp}/{kernel_id} are substituted by generate_wrapper();
# doubled braces are format escapes for the literal C++ namespace braces.
WRAPPER_TEMPLATE_GEMM = """// SPDX-License-Identifier: MIT
// Auto-generated wrapper for: {kernel_name}
// This file enables per-kernel parallel compilation

#include "{kernel_hpp}"

// Force symbol emission for kernel registration
namespace ck_tile {{
namespace dispatcher {{
namespace generated {{

// Marker to prevent dead code elimination
volatile bool _{kernel_id}_registered = true;

}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile
"""

# Wrapper emitted for every non-GEMM kernel.  Identical to the GEMM template
# except for the omitted explanatory comments in the generated output.
WRAPPER_TEMPLATE_CONV = """// SPDX-License-Identifier: MIT
// Auto-generated wrapper for: {kernel_name}
// This file enables per-kernel parallel compilation

#include "{kernel_hpp}"

namespace ck_tile {{
namespace dispatcher {{
namespace generated {{

volatile bool _{kernel_id}_registered = true;

}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile
"""
|
||||
|
||||
|
||||
def generate_wrapper(
    kernel_hpp: Path, output_dir: Path, index: int, total: int
) -> Tuple[Path, bool]:
    """Emit the .cpp translation-unit wrapper for one kernel header.

    Returns ``(wrapper_path, written)`` where ``written`` is False when the
    on-disk wrapper already had identical content, which keeps incremental
    builds from touching unchanged files.  ``index``/``total`` are accepted
    for caller-side progress reporting and are unused here.
    """
    stem = kernel_hpp.stem
    # Kernel name sanitized into a valid C++ identifier fragment.
    symbol = stem.replace("-", "_").replace(".", "_")

    # GEMM kernels get the gemm template; everything else falls back to conv.
    chosen = WRAPPER_TEMPLATE_GEMM if stem.startswith("gemm") else WRAPPER_TEMPLATE_CONV
    body = chosen.format(
        kernel_name=stem,
        kernel_hpp=kernel_hpp.name,
        kernel_id=symbol,
    )

    target = output_dir / f"{stem}.cpp"

    # Skip the write when nothing changed so incremental rebuilds stay cheap.
    if target.exists() and target.read_text() == body:
        return target, False

    target.write_text(body)
    return target, True
|
||||
|
||||
|
||||
def generate_cmake_list(
    wrappers: List[Path], output_dir: Path, kernel_dir: Path
) -> Path:
    """Generate CMakeLists.txt that compiles each wrapper as a separate object.

    Creates one OBJECT library per wrapper (target ``kobj_<kernel>``) and a
    combined SHARED library target ``all_kernels`` (output name
    ``dispatcher_kernels``) linking all objects.

    Args:
        wrappers: Wrapper .cpp paths produced by generate_wrapper().
        output_dir: Directory that receives the CMakeLists.txt.
        kernel_dir: Added to every target's include path (KERNEL_INCLUDE_DIR).

    Returns:
        Path of the written CMakeLists.txt.
    """

    num_kernels = len(wrappers)

    # Doubled braces escape literal CMake ${VAR} references in the f-strings.
    cmake_content = f'''# SPDX-License-Identifier: MIT
# Auto-generated CMakeLists.txt for per-kernel parallel compilation
# Generated {num_kernels} kernel translation units

cmake_minimum_required(VERSION 3.16)

# =============================================================================
# Per-Kernel Object Targets ({num_kernels} kernels)
# =============================================================================
# Each kernel is compiled as a separate OBJECT library for maximum parallelism.
# Build with: make -j$(nproc) all_kernels
#
# Progress output:
# [ 1/{num_kernels}] Building kernel: gemm_fp16_rcr_128x128x32
# [ 2/{num_kernels}] Building kernel: gemm_fp16_rcr_256x256x64
# ...

set(KERNEL_INCLUDE_DIR "{kernel_dir}")
set(ALL_KERNEL_OBJECTS "")

'''

    # One OBJECT library per wrapper, accumulated into ALL_KERNEL_OBJECTS.
    for idx, wrapper in enumerate(wrappers, 1):
        kernel_name = wrapper.stem
        obj_target = f"kobj_{kernel_name}"

        cmake_content += f"""
# [{idx}/{num_kernels}] {kernel_name}
add_library({obj_target} OBJECT {wrapper.name})
target_include_directories({obj_target} PRIVATE ${{KERNEL_INCLUDE_DIR}} ${{CK_INCLUDE_DIR}})
target_compile_options({obj_target} PRIVATE
    -mllvm -enable-noalias-to-md-conversion=0
    -Wno-undefined-func-template
    -Wno-float-equal
    --offload-compress
)
set_target_properties({obj_target} PROPERTIES POSITION_INDEPENDENT_CODE ON)
if(hip_FOUND)
    target_link_libraries({obj_target} PRIVATE hip::device hip::host)
endif()
list(APPEND ALL_KERNEL_OBJECTS $<TARGET_OBJECTS:{obj_target}>)
"""

    # Combined shared library linking every kernel object.
    cmake_content += f"""

# =============================================================================
# Combined Kernel Library
# =============================================================================
# Links all {num_kernels} kernel objects into a single shared library

add_library(all_kernels SHARED ${{ALL_KERNEL_OBJECTS}})
if(hip_FOUND)
    target_link_libraries(all_kernels PRIVATE hip::device hip::host)
endif()
set_target_properties(all_kernels PROPERTIES
    POSITION_INDEPENDENT_CODE ON
    OUTPUT_NAME "dispatcher_kernels"
)

message(STATUS "Configured {num_kernels} kernel objects for parallel compilation")
message(STATUS "Build with: make -j$(nproc) all_kernels")
"""

    cmake_file = output_dir / "CMakeLists.txt"
    cmake_file.write_text(cmake_content)
    return cmake_file
|
||||
|
||||
|
||||
def generate_ninja_build(
    wrappers: List[Path], output_dir: Path, kernel_dir: Path
) -> Path:
    """Generate build.ninja for even faster parallel compilation.

    Emits a self-contained ninja file with a compile rule (hipcc) per kernel
    object and a link rule producing libdispatcher_kernels.so.

    Args:
        wrappers: Wrapper .cpp paths produced by generate_wrapper().
        output_dir: Directory that receives build.ninja.
        kernel_dir: Added to the compile include path.

    Returns:
        Path of the written build.ninja.
    """

    num_kernels = len(wrappers)

    ninja_content = f"""# SPDX-License-Identifier: MIT
# Auto-generated build.ninja for per-kernel parallel compilation
# {num_kernels} kernel translation units

# Variables
cxx = hipcc
cxxflags = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal --offload-compress
includes = -I{kernel_dir} -I/opt/rocm/include

# Rules
rule compile
  command = $cxx $cxxflags $includes -c $in -o $out
  description = [{num_kernels}] Building kernel: $kernel_name

rule link
  command = $cxx -shared $in -o $out -L/opt/rocm/lib -lamdhip64
  description = Linking: $out

# Kernel objects
"""

    # One build statement per kernel object.  `idx` is currently unused.
    obj_files = []
    for idx, wrapper in enumerate(wrappers, 1):
        kernel_name = wrapper.stem
        obj_file = f"{kernel_name}.o"
        obj_files.append(obj_file)

        ninja_content += f"""
build {obj_file}: compile {wrapper.name}
  kernel_name = {kernel_name}
"""

    # Link all objects into the shared library and make it the default target.
    ninja_content += f"""

# Shared library
build libdispatcher_kernels.so: link {" ".join(obj_files)}

# Default target
default libdispatcher_kernels.so
"""

    ninja_file = output_dir / "build.ninja"
    ninja_file.write_text(ninja_content)
    return ninja_file
|
||||
|
||||
|
||||
def generate_makefile(wrappers: List[Path], output_dir: Path, kernel_dir: Path) -> Path:
    """Generate a Makefile for per-kernel parallel compilation.

    One object rule is emitted per wrapper so ``make -j`` can compile every
    kernel concurrently before linking libdispatcher_kernels.so.

    Args:
        wrappers: Per-kernel wrapper .cpp files (names/stems only are used).
        output_dir: Existing directory the Makefile is written into.
        kernel_dir: Directory with generated kernel headers (added to -I).

    Returns:
        Path to the written Makefile.
    """

    total = len(wrappers)
    stems = [w.stem for w in wrappers]
    objects = [f"{stem}.o" for stem in stems]

    # Assemble the file as a list of sections, joined once at the end.
    sections = [f"""# SPDX-License-Identifier: MIT
# Auto-generated Makefile for per-kernel parallel compilation
# {total} kernel translation units
#
# Usage:
#   make -j$(nproc)              # Build all kernels in parallel
#   make -j$(nproc) VERBOSE=1    # With per-kernel progress
#   make clean                   # Remove all objects

CXX = hipcc
CXXFLAGS = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 \\
           -Wno-undefined-func-template -Wno-float-equal --offload-compress
INCLUDES = -I{kernel_dir} -I/opt/rocm/include
LDFLAGS = -shared -L/opt/rocm/lib -lamdhip64

TARGET = libdispatcher_kernels.so
OBJECTS = {" ".join(objects)}

# Progress counter (only works with make -j1, use ninja for parallel progress)
TOTAL_KERNELS = {total}
CURRENT = 0

.PHONY: all clean

all: $(TARGET)
\t@echo "Built $(TARGET) with {total} kernels"

$(TARGET): $(OBJECTS)
\t@echo "[LINK] Linking {total} kernel objects -> $@"
\t$(CXX) $(LDFLAGS) $^ -o $@

"""]

    # One compile rule per kernel, with a [i/N] progress echo.
    for position, (wrapper, obj) in enumerate(zip(wrappers, objects), 1):
        sections.append(f"""
{obj}: {wrapper.name}
\t@echo "[{position}/{total}] Building kernel: {wrapper.stem}"
\t$(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
""")

    sections.append(f"""

clean:
\trm -f $(OBJECTS) $(TARGET)
\t@echo "Cleaned {total} kernel objects"
""")

    makefile = output_dir / "Makefile"
    makefile.write_text("".join(sections))
    return makefile
|
||||
|
||||
|
||||
def main():
    """CLI entry point: generate per-kernel wrapper files and build scripts.

    Scans --kernel-dir for kernel headers, writes one wrapper .cpp per
    header into --output-dir (in parallel by default), and optionally emits
    CMake / ninja / make build files that compile each wrapper as its own
    translation unit.

    Returns:
        Process exit code: 0 on success, 1 on missing directory or no
        matching headers.
    """
    parser = argparse.ArgumentParser(
        description="Generate per-kernel wrapper .cpp files for parallel compilation"
    )
    parser.add_argument(
        "--kernel-dir",
        type=Path,
        required=True,
        help="Directory containing generated kernel .hpp files",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Output directory for wrapper .cpp files",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        default="*.hpp",
        help="Glob pattern for kernel headers (default: *.hpp)",
    )
    parser.add_argument(
        "--generate-cmake",
        action="store_true",
        help="Generate CMakeLists.txt for the wrappers",
    )
    parser.add_argument(
        "--generate-ninja",
        action="store_true",
        help="Generate build.ninja for ninja builds",
    )
    parser.add_argument(
        "--generate-makefile",
        action="store_true",
        help="Generate Makefile for make builds",
    )
    # BUGFIX: this flag used action="store_true" with default=True, which
    # made it a no-op — it could never be turned off. BooleanOptionalAction
    # keeps --parallel accepted and adds --no-parallel to disable it.
    parser.add_argument(
        "--parallel",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Generate wrappers in parallel (default: True)",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Verbose output",
    )

    args = parser.parse_args()

    # Find kernel headers
    kernel_dir = args.kernel_dir.resolve()
    if not kernel_dir.exists():
        print(f"Error: Kernel directory not found: {kernel_dir}", file=sys.stderr)
        return 1

    kernel_headers = sorted(kernel_dir.glob(args.pattern))
    if not kernel_headers:
        print(
            f"Error: No kernel headers found matching {args.pattern} in {kernel_dir}",
            file=sys.stderr,
        )
        return 1

    num_kernels = len(kernel_headers)
    print(f"Found {num_kernels} kernel headers in {kernel_dir}")

    # Create output directory
    output_dir = args.output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate wrappers (optionally fanned out over a thread pool; the work
    # is I/O-bound file writing, so threads are sufficient).
    print(f"Generating {num_kernels} wrapper .cpp files...")

    wrappers = []
    written = 0

    if args.parallel and num_kernels > 1:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(
                    generate_wrapper, hpp, output_dir, idx, num_kernels
                ): hpp
                for idx, hpp in enumerate(kernel_headers, 1)
            }
            for future in concurrent.futures.as_completed(futures):
                wrapper_path, was_written = future.result()
                wrappers.append(wrapper_path)
                if was_written:
                    written += 1
                if args.verbose:
                    print(f"  Generated: {wrapper_path.name}")
    else:
        for idx, hpp in enumerate(kernel_headers, 1):
            wrapper_path, was_written = generate_wrapper(
                hpp, output_dir, idx, num_kernels
            )
            wrappers.append(wrapper_path)
            if was_written:
                written += 1
            if args.verbose:
                print(f"  [{idx}/{num_kernels}] Generated: {wrapper_path.name}")

    # Restore deterministic ordering (as_completed yields out of order).
    wrappers.sort(key=lambda p: p.name)

    print(
        f"  Total: {num_kernels} wrappers ({written} written, {num_kernels - written} unchanged)"
    )

    # Generate build files
    if args.generate_cmake:
        cmake_file = generate_cmake_list(wrappers, output_dir, kernel_dir)
        print(f"  Generated: {cmake_file}")

    if args.generate_ninja:
        ninja_file = generate_ninja_build(wrappers, output_dir, kernel_dir)
        print(f"  Generated: {ninja_file}")

    if args.generate_makefile:
        makefile = generate_makefile(wrappers, output_dir, kernel_dir)
        print(f"  Generated: {makefile}")

    print(f"\nOutput directory: {output_dir}")
    print(f"Kernels ready for parallel compilation: {num_kernels}")
    print("\nTo build:")
    print(f"  cd {output_dir}")
    if args.generate_makefile:
        print("  make -j$(nproc)    # Parallel build with progress")
    if args.generate_ninja:
        print("  ninja              # Fast parallel build")
    if args.generate_cmake:
        print("  cmake -B build && cmake --build build -j$(nproc)")

    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the CLI and propagate its integer exit code to the shell.
    sys.exit(main())
|
||||
798
dispatcher/codegen/kernel_config_loader.py
Normal file
798
dispatcher/codegen/kernel_config_loader.py
Normal file
@@ -0,0 +1,798 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Kernel Configuration Loader
|
||||
|
||||
Load kernel configurations from JSON files for generating specific kernel sets.
|
||||
Compatible with tile_engine JSON format.
|
||||
|
||||
Usage:
|
||||
from kernel_config_loader import load_kernel_configs, KernelConfigSet
|
||||
|
||||
# Load configs from JSON
|
||||
config_set = load_kernel_configs("my_kernels.json")
|
||||
|
||||
# Get all configurations (cartesian product of all parameter values)
|
||||
for config in config_set.generate_configs():
|
||||
print(config)
|
||||
|
||||
# Use with codegen
|
||||
from unified_gemm_codegen import UnifiedGemmCodegen
|
||||
codegen = UnifiedGemmCodegen(...)
|
||||
codegen.generate_from_configs(config_set.generate_configs())
|
||||
"""
|
||||
|
||||
import json
|
||||
import itertools
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Iterator
|
||||
|
||||
|
||||
@dataclass
class TileConfig:
    """Tile configuration for a kernel"""

    # Block-level tile sizes.
    tile_m: int = 128
    tile_n: int = 128
    tile_k: int = 32
    # Warp counts per block tile — presumed from naming; confirm in codegen.
    warp_m: int = 2
    warp_n: int = 2
    warp_k: int = 1
    # Per-warp tile sizes.
    warp_tile_m: int = 32
    warp_tile_n: int = 32
    warp_tile_k: int = 16
|
||||
|
||||
|
||||
@dataclass
class TraitConfig:
    """Trait configuration for a kernel (order matches GEMM/Conv TraitConfig)"""

    pipeline: str = "compv4"  # pipeline variant name forwarded to codegen
    epilogue: str = "cshuffle"  # epilogue strategy forwarded to codegen
    scheduler: str = "intrawave"  # scheduler name forwarded to codegen
    # Padding flags, forwarded to codegen via KernelConfig.to_dict().
    pad_m: bool = False
    pad_n: bool = False
    pad_k: bool = False
|
||||
|
||||
|
||||
@dataclass
class KernelConfig:
    """Complete kernel configuration"""

    tile: TileConfig = field(default_factory=TileConfig)
    trait: TraitConfig = field(default_factory=TraitConfig)
    dtype_a: str = "fp16"
    dtype_b: str = "fp16"
    dtype_c: str = "fp16"
    dtype_acc: str = "fp32"
    layout: str = "rcr"
    gpu_target: str = "gfx942"
    variant: str = "standard"

    def to_dict(self) -> Dict[str, Any]:
        """Flatten the nested tile/trait config into the flat dict codegen expects."""
        t, tr = self.tile, self.trait
        return {
            "tile_m": t.tile_m,
            "tile_n": t.tile_n,
            "tile_k": t.tile_k,
            "warp_m": t.warp_m,
            "warp_n": t.warp_n,
            "warp_k": t.warp_k,
            "warp_tile_m": t.warp_tile_m,
            "warp_tile_n": t.warp_tile_n,
            "warp_tile_k": t.warp_tile_k,
            "pipeline": tr.pipeline,
            "scheduler": tr.scheduler,
            "epilogue": tr.epilogue,
            "pad_m": tr.pad_m,
            "pad_n": tr.pad_n,
            "pad_k": tr.pad_k,
            "dtype_a": self.dtype_a,
            "dtype_b": self.dtype_b,
            "dtype_c": self.dtype_c,
            "dtype_acc": self.dtype_acc,
            "layout": self.layout,
            "gpu_target": self.gpu_target,
            "variant": self.variant,
        }

    def kernel_name(self) -> str:
        """Build the unique kernel identifier string for this configuration."""
        t, tr = self.tile, self.trait
        # Padding flags are rendered as "True"/"False" tokens.
        pads = [str(flag).capitalize() for flag in (tr.pad_m, tr.pad_n, tr.pad_k)]
        pieces = [
            f"gemm_{self.dtype_a}_{self.layout}_{tr.pipeline}",
            tr.epilogue,
            tr.scheduler,
            *pads,
            "False",  # preshuffle
            f"{t.tile_m}x{t.tile_n}x{t.tile_k}",
            f"{t.warp_m}x{t.warp_n}x{t.warp_k}",
            f"{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}",
        ]
        return "_".join(pieces)
|
||||
|
||||
|
||||
@dataclass
class KernelConfigSet:
    """A set of kernel configurations loaded from JSON"""

    name: str = "default"
    configs: List[KernelConfig] = field(default_factory=list)

    # Parameter ranges for generation
    tile_m_values: List[int] = field(default_factory=lambda: [128])
    tile_n_values: List[int] = field(default_factory=lambda: [128])
    tile_k_values: List[int] = field(default_factory=lambda: [32])
    warp_m_values: List[int] = field(default_factory=lambda: [2])
    warp_n_values: List[int] = field(default_factory=lambda: [2])
    warp_k_values: List[int] = field(default_factory=lambda: [1])
    warp_tile_m_values: List[int] = field(default_factory=lambda: [32])
    warp_tile_n_values: List[int] = field(default_factory=lambda: [32])
    warp_tile_k_values: List[int] = field(default_factory=lambda: [16])

    pipeline_values: List[str] = field(default_factory=lambda: ["compv4"])
    scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"])
    epilogue_values: List[str] = field(default_factory=lambda: ["cshuffle"])
    pad_m_values: List[bool] = field(default_factory=lambda: [False])
    pad_n_values: List[bool] = field(default_factory=lambda: [False])
    pad_k_values: List[bool] = field(default_factory=lambda: [False])

    dtype_a: str = "fp16"
    dtype_b: str = "fp16"
    dtype_c: str = "fp16"
    dtype_acc: str = "fp32"
    layout: str = "rcr"
    gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"])
    variant: str = "standard"

    def generate_configs(self) -> Iterator[KernelConfig]:
        """Yield every kernel configuration (full cartesian product)."""
        # Materialize the products once so they can be iterated per target.
        tile_space = list(
            itertools.product(
                self.tile_m_values,
                self.tile_n_values,
                self.tile_k_values,
                self.warp_m_values,
                self.warp_n_values,
                self.warp_k_values,
                self.warp_tile_m_values,
                self.warp_tile_n_values,
                self.warp_tile_k_values,
            )
        )
        trait_space = list(
            itertools.product(
                self.pipeline_values,
                self.scheduler_values,
                self.epilogue_values,
                self.pad_m_values,
                self.pad_n_values,
                self.pad_k_values,
            )
        )

        for gpu_target in self.gpu_targets:
            for tm, tn, tk, wm, wn, wk, wtm, wtn, wtk in tile_space:
                for pipe, sched, epi, pm, pn, pk in trait_space:
                    yield KernelConfig(
                        tile=TileConfig(
                            tile_m=tm,
                            tile_n=tn,
                            tile_k=tk,
                            warp_m=wm,
                            warp_n=wn,
                            warp_k=wk,
                            warp_tile_m=wtm,
                            warp_tile_n=wtn,
                            warp_tile_k=wtk,
                        ),
                        trait=TraitConfig(
                            pipeline=pipe,
                            scheduler=sched,
                            epilogue=epi,
                            pad_m=pm,
                            pad_n=pn,
                            pad_k=pk,
                        ),
                        dtype_a=self.dtype_a,
                        dtype_b=self.dtype_b,
                        dtype_c=self.dtype_c,
                        dtype_acc=self.dtype_acc,
                        layout=self.layout,
                        gpu_target=gpu_target,
                        variant=self.variant,
                    )

    def config_count(self) -> int:
        """Total number of configurations generate_configs() will yield."""

        def _prod(groups):
            # Product of the group lengths (avoids a math.prod dependency).
            total = 1
            for group in groups:
                total *= len(group)
            return total

        tile_count = _prod((
            self.tile_m_values,
            self.tile_n_values,
            self.tile_k_values,
            self.warp_m_values,
            self.warp_n_values,
            self.warp_k_values,
            self.warp_tile_m_values,
            self.warp_tile_n_values,
            self.warp_tile_k_values,
        ))
        trait_count = _prod((
            self.pipeline_values,
            self.scheduler_values,
            self.epilogue_values,
            self.pad_m_values,
            self.pad_n_values,
            self.pad_k_values,
        ))
        return tile_count * trait_count * len(self.gpu_targets)
|
||||
|
||||
|
||||
def _get_values(config: Dict, key: str, default: List) -> List:
|
||||
"""Extract values from config dict, handling range specifications"""
|
||||
if key not in config:
|
||||
return default
|
||||
|
||||
item = config[key]
|
||||
|
||||
# Explicit values list
|
||||
if "values" in item:
|
||||
return item["values"]
|
||||
|
||||
# Range specification (min, max, step)
|
||||
if "min" in item and "max" in item:
|
||||
min_val = item["min"]
|
||||
max_val = item["max"]
|
||||
step = item.get("step", 1)
|
||||
return list(range(min_val, max_val + 1, step))
|
||||
|
||||
return default
|
||||
|
||||
|
||||
def load_kernel_configs(json_path: str | Path) -> KernelConfigSet:
    """
    Load kernel configurations from a JSON file.

    Supports both tile_engine format and dispatcher format.

    Args:
        json_path: Path to JSON configuration file

    Returns:
        KernelConfigSet with all parameter values loaded
    """
    json_path = Path(json_path)

    with open(json_path) as f:
        data = json.load(f)

    cfg = KernelConfigSet()

    # Name falls back to the file stem when the JSON omits it.
    cfg.name = data.get("kernel_set_name", json_path.stem)

    # Optional "datatype" section.
    if "datatype" in data:
        dt = data["datatype"]
        cfg.dtype_a = dt.get("a", "fp16")
        cfg.dtype_b = dt.get("b", "fp16")
        cfg.dtype_c = dt.get("c", "fp16")
        cfg.dtype_acc = dt.get("acc", "fp32")

    cfg.layout = data.get("layout", "rcr")

    # GPU targets: plural key wins over singular.
    if "gpu_targets" in data:
        cfg.gpu_targets = data["gpu_targets"]
    elif "gpu_target" in data:
        cfg.gpu_targets = [data["gpu_target"]]

    cfg.variant = data.get("variant", "standard")

    # Tile parameter lists, loaded table-driven into <key>_values attributes.
    tile_section = data.get("tile_config", {})
    for key, fallback in (
        ("tile_m", [128]),
        ("tile_n", [128]),
        ("tile_k", [32]),
        ("warp_m", [2]),
        ("warp_n", [2]),
        ("warp_k", [1]),
        ("warp_tile_m", [32]),
        ("warp_tile_n", [32]),
        ("warp_tile_k", [16]),
    ):
        setattr(cfg, f"{key}_values", _get_values(tile_section, key, fallback))

    # Trait parameter lists.
    trait_section = data.get("trait_config", {})
    for key, fallback in (
        ("pipeline", ["compv4"]),
        ("scheduler", ["intrawave"]),
        ("epilogue", ["cshuffle"]),
        ("pad_m", [False]),
        ("pad_n", [False]),
        ("pad_k", [False]),
    ):
        setattr(cfg, f"{key}_values", _get_values(trait_section, key, fallback))

    return cfg
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Convolution Configuration Classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
class ConvTileConfig:
    """Tile configuration for a convolution kernel.

    Mirrors the GEMM TileConfig; the implicit-GEMM dimension mapping for the
    forward pass is noted on the block-tile fields.
    """

    tile_m: int = 128  # M dimension (N * spatial_out for fwd)
    tile_n: int = 128  # N dimension (K output channels for fwd)
    tile_k: int = 32  # K dimension (C * filter for fwd)
    # Warp counts per block tile — presumed from naming; confirm in codegen.
    warp_m: int = 2
    warp_n: int = 2
    warp_k: int = 1
    # Per-warp tile sizes.
    warp_tile_m: int = 32
    warp_tile_n: int = 32
    warp_tile_k: int = 16
|
||||
|
||||
|
||||
@dataclass
class ConvTraitConfig:
    """Trait configuration for a convolution kernel"""

    pipeline: str = "compv3"  # pipeline variant name forwarded to codegen
    scheduler: str = "intrawave"  # scheduler name forwarded to codegen
    epilogue: str = "cshuffle"  # epilogue strategy forwarded to codegen
    # Conv defaults enable padding on all dims (unlike the GEMM TraitConfig).
    pad_m: bool = True
    pad_n: bool = True
    pad_k: bool = True
    double_smem_buffer: bool = False  # presumably double LDS buffering — confirm in codegen
    num_groups_to_merge: int = 1  # presumably groups merged per launch — confirm in codegen
|
||||
|
||||
|
||||
@dataclass
class ConvKernelConfig:
    """Complete convolution kernel configuration"""

    tile: ConvTileConfig = field(default_factory=ConvTileConfig)
    trait: ConvTraitConfig = field(default_factory=ConvTraitConfig)
    dtype_input: str = "fp16"
    dtype_weight: str = "fp16"
    dtype_output: str = "fp16"
    dtype_acc: str = "fp32"
    variant: str = "forward"  # forward, bwd_data, bwd_weight
    ndim: int = 2  # 1, 2, or 3
    layout: str = "nhwgc"
    gpu_target: str = "gfx942"

    # Vector sizes
    vector_size_a: int = 4
    vector_size_b: int = 8
    vector_size_c: int = 8

    # Occupancy
    block_per_cu: int = 1
    num_wave_groups: int = 1

    def to_dict(self) -> Dict[str, Any]:
        """Flatten the nested tile/trait config into the flat dict codegen expects."""
        t, tr = self.tile, self.trait
        return {
            "tile_m": t.tile_m,
            "tile_n": t.tile_n,
            "tile_k": t.tile_k,
            "warp_m": t.warp_m,
            "warp_n": t.warp_n,
            "warp_k": t.warp_k,
            "warp_tile_m": t.warp_tile_m,
            "warp_tile_n": t.warp_tile_n,
            "warp_tile_k": t.warp_tile_k,
            "pipeline": tr.pipeline,
            "scheduler": tr.scheduler,
            "epilogue": tr.epilogue,
            "pad_m": tr.pad_m,
            "pad_n": tr.pad_n,
            "pad_k": tr.pad_k,
            "double_smem_buffer": tr.double_smem_buffer,
            "num_groups_to_merge": tr.num_groups_to_merge,
            "dtype_input": self.dtype_input,
            "dtype_weight": self.dtype_weight,
            "dtype_output": self.dtype_output,
            "dtype_acc": self.dtype_acc,
            "variant": self.variant,
            "ndim": self.ndim,
            "layout": self.layout,
            "gpu_target": self.gpu_target,
            "vector_size_a": self.vector_size_a,
            "vector_size_b": self.vector_size_b,
            "vector_size_c": self.vector_size_c,
            "block_per_cu": self.block_per_cu,
            "num_wave_groups": self.num_wave_groups,
        }

    def kernel_name(self) -> str:
        """Build the unique kernel identifier string for this configuration."""
        abbrev = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"}
        t, tr = self.tile, self.trait
        pieces = [
            f"conv_{abbrev.get(self.variant, self.variant)}_{self.dtype_input}_{self.ndim}d",
            tr.pipeline,
            tr.epilogue,
            tr.scheduler,
            f"{t.tile_m}x{t.tile_n}x{t.tile_k}",
            f"{t.warp_m}x{t.warp_n}x{t.warp_k}",
            f"{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}",
        ]
        return "_".join(pieces)
|
||||
|
||||
|
||||
@dataclass
class ConvKernelConfigSet:
    """A set of convolution kernel configurations loaded from JSON"""

    name: str = "default"
    configs: List[ConvKernelConfig] = field(default_factory=list)

    # Tile parameter ranges
    tile_m_values: List[int] = field(default_factory=lambda: [128])
    tile_n_values: List[int] = field(default_factory=lambda: [128])
    tile_k_values: List[int] = field(default_factory=lambda: [32])
    warp_m_values: List[int] = field(default_factory=lambda: [2])
    warp_n_values: List[int] = field(default_factory=lambda: [2])
    warp_k_values: List[int] = field(default_factory=lambda: [1])
    warp_tile_m_values: List[int] = field(default_factory=lambda: [32])
    warp_tile_n_values: List[int] = field(default_factory=lambda: [32])
    warp_tile_k_values: List[int] = field(default_factory=lambda: [16])

    # Trait parameter ranges
    pipeline_values: List[str] = field(default_factory=lambda: ["compv3"])
    scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"])
    epilogue_values: List[str] = field(default_factory=lambda: ["cshuffle"])
    pad_m_values: List[bool] = field(default_factory=lambda: [True])
    pad_n_values: List[bool] = field(default_factory=lambda: [True])
    pad_k_values: List[bool] = field(default_factory=lambda: [True])
    double_smem_buffer_values: List[bool] = field(default_factory=lambda: [False])
    num_groups_to_merge_values: List[int] = field(default_factory=lambda: [1])

    # Vector sizes
    vector_size_a_values: List[int] = field(default_factory=lambda: [4])
    vector_size_b_values: List[int] = field(default_factory=lambda: [8])
    vector_size_c_values: List[int] = field(default_factory=lambda: [8])

    # Occupancy
    block_per_cu_values: List[int] = field(default_factory=lambda: [1])
    num_wave_groups_values: List[int] = field(default_factory=lambda: [1])

    # Data types
    dtype_input: str = "fp16"
    dtype_weight: str = "fp16"
    dtype_output: str = "fp16"
    dtype_acc: str = "fp32"

    # Conv specific
    variant: str = "forward"
    ndim: int = 2
    layout: str = "nhwgc"
    gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"])

    def generate_configs(self) -> Iterator[ConvKernelConfig]:
        """Yield every convolution kernel configuration (full cartesian product)."""
        # Materialize each product once so it can be re-iterated per target.
        tile_space = list(
            itertools.product(
                self.tile_m_values,
                self.tile_n_values,
                self.tile_k_values,
                self.warp_m_values,
                self.warp_n_values,
                self.warp_k_values,
                self.warp_tile_m_values,
                self.warp_tile_n_values,
                self.warp_tile_k_values,
            )
        )
        trait_space = list(
            itertools.product(
                self.pipeline_values,
                self.scheduler_values,
                self.epilogue_values,
                self.pad_m_values,
                self.pad_n_values,
                self.pad_k_values,
                self.double_smem_buffer_values,
                self.num_groups_to_merge_values,
            )
        )
        extra_space = list(
            itertools.product(
                self.vector_size_a_values,
                self.vector_size_b_values,
                self.vector_size_c_values,
                self.block_per_cu_values,
                self.num_wave_groups_values,
            )
        )

        for gpu_target in self.gpu_targets:
            for tm, tn, tk, wm, wn, wk, wtm, wtn, wtk in tile_space:
                for pipe, sched, epi, pm, pn, pk, dbl, merge in trait_space:
                    for va, vb, vc, bpc, nwg in extra_space:
                        yield ConvKernelConfig(
                            tile=ConvTileConfig(
                                tile_m=tm,
                                tile_n=tn,
                                tile_k=tk,
                                warp_m=wm,
                                warp_n=wn,
                                warp_k=wk,
                                warp_tile_m=wtm,
                                warp_tile_n=wtn,
                                warp_tile_k=wtk,
                            ),
                            trait=ConvTraitConfig(
                                pipeline=pipe,
                                scheduler=sched,
                                epilogue=epi,
                                pad_m=pm,
                                pad_n=pn,
                                pad_k=pk,
                                double_smem_buffer=dbl,
                                num_groups_to_merge=merge,
                            ),
                            dtype_input=self.dtype_input,
                            dtype_weight=self.dtype_weight,
                            dtype_output=self.dtype_output,
                            dtype_acc=self.dtype_acc,
                            variant=self.variant,
                            ndim=self.ndim,
                            layout=self.layout,
                            gpu_target=gpu_target,
                            vector_size_a=va,
                            vector_size_b=vb,
                            vector_size_c=vc,
                            block_per_cu=bpc,
                            num_wave_groups=nwg,
                        )

    def config_count(self) -> int:
        """Total number of configurations generate_configs() will yield."""

        def _prod(groups):
            # Product of the group lengths (avoids a math.prod dependency).
            total = 1
            for group in groups:
                total *= len(group)
            return total

        tile_count = _prod((
            self.tile_m_values,
            self.tile_n_values,
            self.tile_k_values,
            self.warp_m_values,
            self.warp_n_values,
            self.warp_k_values,
            self.warp_tile_m_values,
            self.warp_tile_n_values,
            self.warp_tile_k_values,
        ))
        trait_count = _prod((
            self.pipeline_values,
            self.scheduler_values,
            self.epilogue_values,
            self.pad_m_values,
            self.pad_n_values,
            self.pad_k_values,
            self.double_smem_buffer_values,
            self.num_groups_to_merge_values,
        ))
        extra_count = _prod((
            self.vector_size_a_values,
            self.vector_size_b_values,
            self.vector_size_c_values,
            self.block_per_cu_values,
            self.num_wave_groups_values,
        ))
        return tile_count * trait_count * extra_count * len(self.gpu_targets)
|
||||
|
||||
|
||||
def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
    """
    Load convolution kernel configurations from a JSON file.

    Args:
        json_path: Path to JSON configuration file

    Returns:
        ConvKernelConfigSet with all parameter values loaded
    """
    json_path = Path(json_path)

    with open(json_path) as f:
        data = json.load(f)

    cfg = ConvKernelConfigSet()

    # Name falls back to the file stem when the JSON omits it.
    cfg.name = data.get("kernel_set_name", json_path.stem)

    # Optional "datatype" section.
    if "datatype" in data:
        dt = data["datatype"]
        cfg.dtype_input = dt.get("input", "fp16")
        cfg.dtype_weight = dt.get("weight", "fp16")
        cfg.dtype_output = dt.get("output", "fp16")
        cfg.dtype_acc = dt.get("acc", "fp32")

    # Conv specific
    cfg.variant = data.get("variant", "forward")
    cfg.ndim = data.get("ndim", 2)
    cfg.layout = data.get("layout", "nhwgc")

    # GPU targets: plural key wins over singular.
    if "gpu_targets" in data:
        cfg.gpu_targets = data["gpu_targets"]
    elif "gpu_target" in data:
        cfg.gpu_targets = [data["gpu_target"]]

    # (section key, ((param, fallback), ...)) tables drive the rest; each
    # param lands in the matching <param>_values attribute.
    sections = (
        ("tile_config", (
            ("tile_m", [128]),
            ("tile_n", [128]),
            ("tile_k", [32]),
            ("warp_m", [2]),
            ("warp_n", [2]),
            ("warp_k", [1]),
            ("warp_tile_m", [32]),
            ("warp_tile_n", [32]),
            ("warp_tile_k", [16]),
        )),
        ("trait_config", (
            ("pipeline", ["compv3"]),
            ("scheduler", ["intrawave"]),
            ("epilogue", ["cshuffle"]),
            ("pad_m", [True]),
            ("pad_n", [True]),
            ("pad_k", [True]),
            ("double_smem_buffer", [False]),
            ("num_groups_to_merge", [1]),
        )),
        ("vector_config", (
            ("vector_size_a", [4]),
            ("vector_size_b", [8]),
            ("vector_size_c", [8]),
        )),
        ("occupancy_config", (
            ("block_per_cu", [1]),
            ("num_wave_groups", [1]),
        )),
    )
    for section_key, params in sections:
        section = data.get(section_key, {})
        for param, fallback in params:
            setattr(cfg, f"{param}_values", _get_values(section, param, fallback))

    return cfg
|
||||
|
||||
|
||||
def generate_cpp_conv_kernel_set_declaration(
    config_set: ConvKernelConfigSet,
    set_name: Optional[str] = None,
) -> str:
    """
    Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet.
    """
    label = set_name or config_set.name

    # One .add(...) entry per generated configuration.
    entries = [
        f'    .add("{c.dtype_input}", "{c.variant}", {c.ndim}, '
        f"{c.tile.tile_m}, {c.tile.tile_n}, {c.tile.tile_k})"
        for c in config_set.generate_configs()
    ]

    return "\n".join([f"DECL_CONV_KERNEL_SET({label},", *entries, ");"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# GEMM Configuration Export Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def generate_cpp_kernel_set_declaration(
    config_set: KernelConfigSet,
    set_name: Optional[str] = None,
) -> str:
    """
    Generate C++ DECL_KERNEL_SET code from a KernelConfigSet.

    Args:
        config_set: The kernel configuration set
        set_name: Optional name override for the kernel set

    Returns:
        C++ code string with DECL_KERNEL_SET declaration
    """
    chosen_name = set_name or config_set.name

    # Header line, one .add(...) entry per generated config, closing token.
    parts = [f"DECL_KERNEL_SET({chosen_name},"]
    for cfg in config_set.generate_configs():
        tile = cfg.tile
        parts.append(
            f'    .add("{cfg.dtype_a}", "{cfg.layout}", '
            f"{tile.tile_m}, {tile.tile_n}, {tile.tile_k})"
        )
    parts.append(");")
    return "\n".join(parts)
|
||||
|
||||
|
||||
# CLI for testing
# Usage: python kernel_config_loader.py <config.json> [--cpp]
# Loads a JSON kernel-config file, prints a human-readable summary of the
# tile / trait / vector / occupancy sweeps, and optionally emits the C++
# DECL_KERNEL_SET declaration when --cpp is passed.
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python kernel_config_loader.py <config.json>")
        print("\nLoads kernel configurations from JSON and prints summary.")
        sys.exit(1)

    json_path = sys.argv[1]

    try:
        config_set = load_kernel_configs(json_path)

        # Summary header: identity and top-level settings of the set.
        print(f"Kernel Set: {config_set.name}")
        print(
            f"Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}, Acc={config_set.dtype_acc}"
        )
        print(f"Layout: {config_set.layout}")
        print(f"GPU Targets: {config_set.gpu_targets}")
        print(f"Variant: {config_set.variant}")
        print()
        # Swept tile dimensions (each *_values list is one axis of the sweep).
        print("Tile Configurations:")
        print(f"  tile_m: {config_set.tile_m_values}")
        print(f"  tile_n: {config_set.tile_n_values}")
        print(f"  tile_k: {config_set.tile_k_values}")
        print(f"  warp_m: {config_set.warp_m_values}")
        print(f"  warp_n: {config_set.warp_n_values}")
        print(f"  warp_k: {config_set.warp_k_values}")
        print(
            f"  warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}"
        )
        print()
        # Swept trait options (pipeline/scheduler/epilogue/padding).
        print("Trait Configurations:")
        print(f"  pipeline: {config_set.pipeline_values}")
        print(f"  scheduler: {config_set.scheduler_values}")
        print(f"  epilogue: {config_set.epilogue_values}")
        print(
            f"  padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}"
        )
        print()
        print(f"Total configurations: {config_set.config_count()}")
        print()

        # Print first few config names
        print("Sample kernel names:")
        for i, config in enumerate(config_set.generate_configs()):
            if i >= 5:
                # Stop after five samples; report how many were omitted.
                print(f"  ... and {config_set.config_count() - 5} more")
                break
            print(f"  {config.kernel_name()}")
        print()

        # Generate C++ code
        if "--cpp" in sys.argv:
            print("C++ Declaration:")
            print("-" * 60)
            print(generate_cpp_kernel_set_declaration(config_set))

    except Exception as e:
        # Broad catch is intentional for a CLI entry point: report and exit
        # non-zero rather than dumping a traceback at the user.
        print(f"Error: {e}")
        sys.exit(1)
|
||||
518
dispatcher/codegen/preselected_kernels.py
Normal file
518
dispatcher/codegen/preselected_kernels.py
Normal file
@@ -0,0 +1,518 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Preselected, Benchmarked Kernel Configurations
|
||||
|
||||
Curated kernel sets optimized for different workload characteristics:
|
||||
- Compute-friendly: Large tiles, high arithmetic intensity
|
||||
- Memory-friendly: Smaller tiles, better memory access patterns
|
||||
- Latency-friendly: Minimal tiles, low latency for small problems
|
||||
"""
|
||||
|
||||
from functools import partial, lru_cache
|
||||
from typing import List
|
||||
from unified_gemm_codegen import KernelConfig, TileConfig, TraitConfig, GemmVariant
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Base Configurations
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _base_fp16_rcr_compute() -> partial:
    """Base configuration for compute-intensive FP16 RCR kernels"""
    # Trait settings shared by the whole compute-friendly family:
    # compv4 pipeline + cshuffle epilogue + intrawave scheduling, fully padded.
    compute_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # tile is left as None so each call site supplies its own TileConfig.
    return partial(
        KernelConfig,
        tile=None,
        trait=compute_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
|
||||
|
||||
|
||||
def _base_fp16_rcr_memory() -> partial:
    """Base configuration for memory-intensive FP16 RCR kernels"""
    # NOTE: the 'mem' pipeline is required for the interwave scheduler
    # (compv3/compv4/compv5/compv6 only support intrawave).
    memory_trait = TraitConfig(
        pipeline="mem",
        epilogue="cshuffle",
        scheduler="interwave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # tile is left as None so each call site supplies its own TileConfig.
    return partial(
        KernelConfig,
        tile=None,
        trait=memory_trait,
        variant=GemmVariant.STANDARD,
        block_size=128,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
|
||||
|
||||
|
||||
def _base_fp16_rcr_latency() -> partial:
    """Base configuration for latency-sensitive FP16 RCR kernels"""
    # Minimal-overhead traits: mem pipeline with the default (plain) epilogue.
    latency_trait = TraitConfig(
        pipeline="mem",
        epilogue="default",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # tile is left as None so each call site supplies its own TileConfig.
    return partial(
        KernelConfig,
        tile=None,
        trait=latency_trait,
        variant=GemmVariant.STANDARD,
        block_size=128,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Preselected FP16 RCR Kernels
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_compute() -> List[KernelConfig]:
    """
    Compute-friendly FP16 RCR kernels

    Optimized for:
    - Large M, N dimensions (>= 128)
    - High arithmetic intensity
    - Good occupancy
    - Maximum throughput
    """
    base = _base_fp16_rcr_compute()

    # NOTE: lru_cache returns the same list object on every call; callers
    # must treat the result as read-only.
    # TileConfig positional args appear to be
    # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_k,
    #  warp_tile_m, warp_tile_n, warp_tile_k) -- TODO confirm against
    # the TileConfig definition in unified_gemm_codegen.
    return [
        # Large tiles for maximum compute
        base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)),
        base(tile=TileConfig(256, 128, 32, 4, 2, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 256, 32, 2, 4, 1, 32, 32, 16)),
        # Balanced tiles
        base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)),
        # With persistent kernel for large batches
        base(
            tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16),
            trait=TraitConfig(
                pipeline="compv4",
                epilogue="cshuffle",
                scheduler="intrawave",
                pad_m=False,
                pad_n=False,
                pad_k=False,
                persistent=True,
            ),
        ),
    ]
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_memory() -> List[KernelConfig]:
    """
    Memory-friendly FP16 RCR kernels

    Optimized for:
    - Small to medium M, N dimensions
    - Memory-bound workloads
    - Better cache utilization
    - Lower register pressure
    """
    base = _base_fp16_rcr_memory()

    # NOTE: lru_cache returns the same list object on every call; callers
    # must treat the result as read-only.
    return [
        # Small tiles for memory efficiency
        base(tile=TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16)),
        base(tile=TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16)),
        base(tile=TileConfig(16, 64, 32, 1, 2, 1, 16, 16, 16)),
        base(tile=TileConfig(64, 16, 32, 2, 1, 1, 16, 16, 16)),
        # Medium tiles
        base(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)),
        base(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)),
        base(tile=TileConfig(32, 128, 32, 1, 2, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 32, 32, 2, 1, 1, 32, 32, 16)),
    ]
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_latency() -> List[KernelConfig]:
    """
    Latency-friendly FP16 RCR kernels

    Optimized for:
    - Very small M, N dimensions (< 64)
    - Minimal launch overhead
    - Low latency
    - Quick execution
    """
    base = _base_fp16_rcr_latency()

    # Only the two smallest tile shapes: launch overhead dominates here.
    # NOTE: lru_cache returns the same list object on every call; callers
    # must treat the result as read-only.
    return [
        # Minimal tiles for low latency
        base(tile=TileConfig(16, 32, 32, 1, 1, 1, 16, 16, 16)),
        base(tile=TileConfig(32, 16, 32, 1, 1, 1, 16, 16, 16)),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Preselected Multi-D Kernels
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_multi_d() -> List[KernelConfig]:
    """
    Multi-D GEMM kernels with element-wise fusion

    Common fusions:
    - MultiDAdd: E = C + D0 + D1
    - Relu: E = max(C, 0)
    - Gelu: E = gelu(C)
    """
    base = _base_fp16_rcr_compute()

    configs = []

    # Best-performing tile for fused operations
    tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)

    # Common element-wise operations
    # Cross product: 4 element-wise ops x {1, 2} D-tensors = 8 configs,
    # all sharing the single tile above.
    for ew_op in ["MultiDAdd", "Relu", "Gelu", "FastGelu"]:
        for num_d in [1, 2]:
            configs.append(
                base(
                    tile=tile,
                    variant=GemmVariant.MULTI_D,
                    elementwise_op=ew_op,
                    num_d_tensors=num_d,
                )
            )

    return configs
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_preshuffle() -> List[KernelConfig]:
    """
    Preshuffle GEMM kernels for weight optimization

    Best for:
    - Repeated use of same weights
    - Inference workloads
    - Batch size > 1
    """
    base = _base_fp16_rcr_compute()

    # One large and one balanced tile, both with the B-matrix preshuffle
    # variant enabled.
    return [
        base(
            tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16),
            variant=GemmVariant.PRESHUFFLE,
            preshuffle=True,
        ),
        base(
            tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
            variant=GemmVariant.PRESHUFFLE,
            preshuffle=True,
        ),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unified Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_all() -> List[KernelConfig]:
    """All preselected FP16 RCR kernels"""
    # Concatenate every FP16 RCR family in a fixed order into a new list.
    combined = []
    for group in (
        preselected_fp16_rcr_compute(),
        preselected_fp16_rcr_memory(),
        preselected_fp16_rcr_latency(),
        preselected_fp16_rcr_multi_d(),
        preselected_fp16_rcr_preshuffle(),
    ):
        combined += group
    return combined
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp16_rcr_essential() -> List[KernelConfig]:
    """
    Essential FP16 RCR kernels - minimal set for most workloads

    Covers:
    - 90% of common GEMM sizes
    - Key fusion operations
    - Balanced performance
    """
    base_compute = _base_fp16_rcr_compute()
    base_memory = _base_fp16_rcr_memory()

    # NOTE: lru_cache returns the same list object on every call; callers
    # must treat the result as read-only.
    return [
        # Top compute kernels
        base_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        base_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
        # Top memory kernels
        base_memory(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)),
        base_memory(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)),
        # Essential fusions
        base_compute(
            tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
            variant=GemmVariant.MULTI_D,
            elementwise_op="Relu",
            num_d_tensors=1,
        ),
        base_compute(
            tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
            variant=GemmVariant.MULTI_D,
            elementwise_op="Gelu",
            num_d_tensors=1,
        ),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Default Fallback
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def default_kernel() -> KernelConfig:
    """
    Default fallback kernel - guaranteed to work

    Known-good configuration tested on gfx942
    """
    # Intentionally uncached: returns a fresh KernelConfig each call so a
    # caller mutating it cannot affect later callers.
    return KernelConfig(
        tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
        trait=TraitConfig(
            pipeline="compv4",
            epilogue="cshuffle",
            scheduler="intrawave",
            pad_m=True,
            pad_n=True,
            pad_k=True,
            persistent=False,
        ),
        variant=GemmVariant.STANDARD,
        block_size=256,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# BF16 Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_bf16_rcr_essential() -> List[KernelConfig]:
    """Essential BF16 RCR kernels"""
    # Same compute-oriented traits as the FP16 compute base, inlined here
    # because the _base_* helpers are FP16-specific.
    base_compute = partial(
        KernelConfig,
        tile=None,
        trait=TraitConfig(
            pipeline="compv4",
            epilogue="cshuffle",
            scheduler="intrawave",
            pad_m=True,
            pad_n=True,
            pad_k=True,
            persistent=False,
        ),
        variant=GemmVariant.STANDARD,
        block_size=256,
    )

    return [
        base_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        base_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# INT8 Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_int8_rcr_essential() -> List[KernelConfig]:
    """Essential INT8 RCR kernels for quantized inference"""
    base = partial(
        KernelConfig,
        tile=None,
        trait=TraitConfig(
            pipeline="compv4",
            epilogue="cshuffle",
            scheduler="intrawave",
            pad_m=True,
            pad_n=True,
            pad_k=True,
            persistent=False,
        ),
        variant=GemmVariant.STANDARD,
        block_size=256,
    )

    # K-tile of 64 (vs 32 for fp16 sets) -- presumably to exploit the
    # narrower int8 element width; TODO confirm against benchmarks.
    return [
        base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FP8 Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_fp8_rcr_essential() -> List[KernelConfig]:
    """Essential FP8 RCR kernels for AI training"""
    base = partial(
        KernelConfig,
        tile=None,
        trait=TraitConfig(
            pipeline="compv4",
            epilogue="cshuffle",
            scheduler="intrawave",
            pad_m=True,
            pad_n=True,
            pad_k=True,
            persistent=False,
        ),
        variant=GemmVariant.STANDARD,
        block_size=256,
    )

    # Same tile shapes as the INT8 essential set (K-tile 64 for 8-bit data).
    return [
        base(tile=TileConfig(256, 256, 64, 4, 4, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 128, 64, 2, 2, 1, 32, 32, 16)),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mixed Precision Preselected Sets
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@lru_cache(None)
def preselected_mixed_precision() -> List[KernelConfig]:
    """Mixed-precision kernels (FP16 inputs, FP32 output)"""
    # NOTE(review): the dtype selection itself is not visible here -- it is
    # presumably carried by KernelConfig defaults or by the caller; confirm.
    base = partial(
        KernelConfig,
        tile=None,
        trait=TraitConfig(
            pipeline="compv4",
            epilogue="cshuffle",
            scheduler="intrawave",
            pad_m=True,
            pad_n=True,
            pad_k=True,
            persistent=False,
        ),
        variant=GemmVariant.STANDARD,
        block_size=256,
    )

    return [
        base(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        base(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
    ]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Registry
|
||||
# ============================================================================
|
||||
|
||||
# Registry mapping a set name to its zero-argument factory. Every factory is
# lru_cache-wrapped, so repeated lookups return the same (shared) list.
PRESELECTED_SETS = {
    # FP16 sets
    "fp16_rcr_compute": preselected_fp16_rcr_compute,
    "fp16_rcr_memory": preselected_fp16_rcr_memory,
    "fp16_rcr_latency": preselected_fp16_rcr_latency,
    "fp16_rcr_multi_d": preselected_fp16_rcr_multi_d,
    "fp16_rcr_preshuffle": preselected_fp16_rcr_preshuffle,
    "fp16_rcr_all": preselected_fp16_rcr_all,
    "fp16_rcr_essential": preselected_fp16_rcr_essential,
    # BF16 sets
    "bf16_rcr_essential": preselected_bf16_rcr_essential,
    # INT8 sets
    "int8_rcr_essential": preselected_int8_rcr_essential,
    # FP8 sets
    "fp8_rcr_essential": preselected_fp8_rcr_essential,
    # Mixed precision
    "mixed_precision": preselected_mixed_precision,
}
|
||||
|
||||
|
||||
def get_preselected_set(name: str) -> List[KernelConfig]:
    """Get a preselected kernel set by name"""
    # Registry values are never None, so .get(name) returning None means
    # the name is unknown.
    factory = PRESELECTED_SETS.get(name)
    if factory is None:
        raise ValueError(
            f"Unknown preselected set: {name}. Available: {list(PRESELECTED_SETS.keys())}"
        )
    return factory()
|
||||
|
||||
|
||||
def list_preselected_sets() -> List[str]:
    """List all available preselected sets"""
    # Dict iteration preserves the registry's declaration order.
    return [set_name for set_name in PRESELECTED_SETS]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI for testing
|
||||
# ============================================================================
|
||||
|
||||
# CLI entry point: prints the kernels of one preselected set, or just the
# count when --count-only is passed.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="List preselected kernel configurations"
    )
    parser.add_argument(
        "--set",
        type=str,
        default="fp16_rcr_essential",
        choices=list_preselected_sets(),
        help="Preselected set to display",
    )
    parser.add_argument("--count-only", action="store_true", help="Only show count")

    args = parser.parse_args()

    configs = get_preselected_set(args.set)

    if args.count_only:
        print(f"{args.set}: {len(configs)} kernels")
    else:
        print(f"Preselected set: {args.set}")
        print(f"Total kernels: {len(configs)}\n")
        # 1-based numbering for human-readable output.
        for i, cfg in enumerate(configs, 1):
            print(f"{i}. {cfg.variant.value}")
            print(f"   Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
            print(f"   Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}")
            if cfg.variant == GemmVariant.MULTI_D:
                # Multi-D configs carry extra fusion metadata worth showing.
                print(
                    f"   Element-wise: {cfg.elementwise_op}, D tensors: {cfg.num_d_tensors}"
                )
            print()
|
||||
1713
dispatcher/codegen/unified_gemm_codegen.py
Executable file
1713
dispatcher/codegen/unified_gemm_codegen.py
Executable file
File diff suppressed because it is too large
Load Diff
448
dispatcher/examples/CMakeLists.txt
Normal file
448
dispatcher/examples/CMakeLists.txt
Normal file
@@ -0,0 +1,448 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

cmake_minimum_required(VERSION 3.16)

# Get processor count for parallel builds
include(ProcessorCount)
ProcessorCount(NPROC)
if(NPROC EQUAL 0)
    # ProcessorCount can fail (returns 0); fall back to a sane default.
    set(NPROC 4)
endif()

# GPU target architecture (passed from command line or default to gfx942)
if(NOT DEFINED GPU_TARGETS OR GPU_TARGETS STREQUAL "")
    set(GPU_TARGETS "gfx942" CACHE STRING "GPU architecture target")
endif()
# Extract first target if multiple are provided (we only support single target builds)
string(REPLACE ";" " " GPU_TARGETS_SPACE "${GPU_TARGETS}")
string(REPLACE " " ";" GPU_TARGETS_LIST "${GPU_TARGETS_SPACE}")
list(GET GPU_TARGETS_LIST 0 GPU_TARGET)
message(STATUS "Building for GPU target: ${GPU_TARGET}")

# NOTE: Per-kernel compilation is now automatic via declarative examples
# Each example generates only its declared kernels (from DECL_KERNEL_SET)

# Link to dispatcher library
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../build)

# =============================================================================
# Kernel Output Directory
# =============================================================================

set(KERNEL_OUTPUT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels")
file(MAKE_DIRECTORY ${KERNEL_OUTPUT_DIR})

# =============================================================================
# Kernel Generation Targets (run during 'make', not 'cmake')
# =============================================================================

# Sentinel files to track generation
# The sentinel's mtime is the only dependency tracking: touch it after a
# successful run; delete it to force regeneration.
set(GEMM_SENTINEL "${KERNEL_OUTPUT_DIR}/.gemm_generated")

# Generate GEMM kernels (standard + preshuffle + multi_d) - runs with internal parallelism
# Note: 4-char layout "rcrr" means A=row, B=col, C=row, D=row (for multi-d)
add_custom_command(
    OUTPUT ${GEMM_SENTINEL}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py
        --datatype fp16 --layout rcrr --variants standard preshuffle multi_d
        --output ${KERNEL_OUTPUT_DIR}
    COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen
    COMMENT "Generating GEMM kernels (fp16, rcrr, standard + preshuffle + multi_d) with internal parallelism..."
    VERBATIM
)

add_custom_target(generate_gemm_kernels
    DEPENDS ${GEMM_SENTINEL}
    COMMENT "GEMM kernel generation target"
)

# Alias for generate_all_kernels (GEMM only now)
add_custom_target(generate_all_kernels
    DEPENDS generate_gemm_kernels
)

# =============================================================================
# Per-Kernel Compilation (Maximum Parallelism)
# =============================================================================
# Enable with: cmake -DPER_KERNEL_COMPILATION=ON
#
# This creates ONE translation unit per kernel, enabling:
# 1. Maximum parallelism with make -j$(nproc)
# 2. Per-kernel build progress: "[1/128] Building kernel: gemm_fp16_128x128"
# 3. Incremental rebuilds (only changed kernels recompile)
# 4. Fine-grained build time analysis
#
# Build process:
# 1. Generate kernel headers (.hpp)
# 2. Generate wrapper files (.cpp) - one per kernel
# 3. Compile each wrapper in parallel
# 4. Link all objects into libdispatcher_kernels.so
#
# Example output:
# [ 1/128] Building kernel: gemm_fp16_rcr_128x128x32
# [ 2/128] Building kernel: gemm_fp16_rcr_256x256x64
# ...
# [128/128] Linking: libdispatcher_kernels.so
# =============================================================================

set(WRAPPER_DIR "${CMAKE_BINARY_DIR}/kernel_wrappers")
set(WRAPPER_SENTINEL "${WRAPPER_DIR}/.wrappers_generated")

# Target: Generate wrapper .cpp files (one per kernel)
# Depends on the GEMM sentinel so wrappers are regenerated whenever the
# kernels themselves are.
add_custom_command(
    OUTPUT ${WRAPPER_SENTINEL}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/generate_kernel_wrappers.py
        --kernel-dir ${KERNEL_OUTPUT_DIR}
        --output-dir ${WRAPPER_DIR}
        --generate-makefile
        --generate-cmake
    COMMAND ${CMAKE_COMMAND} -E touch ${WRAPPER_SENTINEL}
    DEPENDS ${GEMM_SENTINEL}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen
    COMMENT "Generating per-kernel wrapper .cpp files..."
    VERBATIM
)

add_custom_target(generate_kernel_wrappers
    DEPENDS ${WRAPPER_SENTINEL}
    COMMENT "Kernel wrapper generation target"
)
|
||||
|
||||
# Target: Build kernels using generated Makefile (true per-kernel progress)
# Invokes the Makefile emitted by generate_kernel_wrappers and filters its
# output down to progress/link/error lines.
add_custom_target(build_kernels_parallel
    COMMAND ${CMAKE_COMMAND} -E echo "Building kernels with per-kernel progress..."
    # VERBATIM COMMANDs are NOT run through a shell, so the `2>&1 | grep`
    # pipeline must be wrapped in an explicit `sh -c` invocation; otherwise
    # the pipe tokens are passed to make as literal arguments.
    COMMAND sh -c "make -C ${WRAPPER_DIR} -j${NPROC} 2>&1 | grep -E '^\\[|Built|Linking|Error'"
    DEPENDS generate_kernel_wrappers
    WORKING_DIRECTORY ${WRAPPER_DIR}
    COMMENT "Compiling kernels in parallel (one translation unit per kernel)..."
    VERBATIM
)
|
||||
|
||||
# Global kernel build (optional - prefer per-example builds for minimal compilation)
# This builds ALL kernels into a shared library - use for Python bindings or full library
# For C++ examples, use declarative approach which builds only needed kernels
# Not part of the default build: it must be requested explicitly
# (e.g. `make dispatcher_kernels`).
add_custom_target(dispatcher_kernels
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/parallel_kernel_builder.py
        --kernel-dir ${KERNEL_OUTPUT_DIR}
        --output-dir ${CMAKE_BINARY_DIR}
        --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include"
        --jobs ${NPROC}
    DEPENDS generate_all_kernels
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts
    COMMENT "Building ALL kernels in parallel (prefer per-example builds for minimal compilation)..."
    VERBATIM
)
|
||||
|
||||
# =============================================================================
|
||||
# Force regeneration targets (useful when you want to regenerate)
|
||||
# =============================================================================
|
||||
|
||||
# Force regeneration: remove the sentinel, rerun the generator, restore it.
# NOTE: the generator flags here must stay in sync with the primary
# generate_gemm_kernels command above — in particular the 4-char layout
# "rcrr" (A=row, B=col, C=row, D=row) required by the multi_d variant.
# (Previously this target used "rcr", silently regenerating a different
# kernel set than the normal build produces.)
add_custom_target(regenerate_gemm_kernels
    COMMAND ${CMAKE_COMMAND} -E remove -f ${GEMM_SENTINEL}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py
        --datatype fp16 --layout rcrr --variants standard preshuffle multi_d
        --output ${KERNEL_OUTPUT_DIR}
    COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen
    COMMENT "Force regenerating GEMM kernels (standard + preshuffle + multi_d)..."
    VERBATIM
)
|
||||
|
||||
# Umbrella target: currently only GEMM kernels exist, so this is an alias.
add_custom_target(regenerate_all_kernels
    DEPENDS regenerate_gemm_kernels
)

# Clean all per-example kernel directories
# Removes every build/<name>_kernels directory created by the declarative
# per-example builds (they are regenerated on the next build).
add_custom_target(clean_example_kernels
    COMMAND ${CMAKE_COMMAND} -E echo "Removing per-example kernel directories..."
    COMMAND find ${CMAKE_BINARY_DIR} -maxdepth 1 -type d -name "*_kernels" -exec rm -rf {} +
    COMMENT "Cleaning all per-example kernel directories..."
    VERBATIM
)
|
||||
|
||||
# =============================================================================
|
||||
# Helper function to add a GPU example with force-included kernel
|
||||
# =============================================================================
|
||||
|
||||
# Helper for GPU examples that use the dispatcher registry
|
||||
# KERNEL_HEADER can be:
|
||||
# - A registration header (register_all_kernels.hpp) - included directly in source
|
||||
# - A specific kernel header - force-included via compiler flag
|
||||
# add_gpu_example(NAME SOURCE KERNEL_HEADER)
# Defines a GPU example executable linked against the dispatcher library.
# KERNEL_HEADER is either the registration header (included by the source
# itself) or a specific kernel header (force-included via -include).
function(add_gpu_example NAME SOURCE KERNEL_HEADER)
    add_executable(${NAME} ${SOURCE})

    target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher)

    target_include_directories(${NAME} PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/../../include  # CK root include
        ${CMAKE_CURRENT_SOURCE_DIR}/../include  # Dispatcher include
        ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels  # Generated kernels
        ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels/dispatcher_wrappers  # Wrapper headers
    )

    # Check if using registration header (no force-include needed)
    get_filename_component(HEADER_NAME ${KERNEL_HEADER} NAME)
    if(HEADER_NAME STREQUAL "register_all_kernels.hpp")
        # Registration header - examples include it directly
        target_compile_options(${NAME} PRIVATE
            -DGEMM_KERNEL_AVAILABLE=1
            -mllvm -enable-noalias-to-md-conversion=0
            -Wno-undefined-func-template
            -Wno-float-equal
            --offload-compress
        )
    else()
        # Specific kernel header - force-include it
        target_compile_options(${NAME} PRIVATE
            -include ${KERNEL_HEADER}
            -mllvm -enable-noalias-to-md-conversion=0
            -Wno-undefined-func-template
            -Wno-float-equal
            --offload-compress
        )
    endif()

    # HIP runtime libraries, only when a HIP toolchain was found.
    if(hip_FOUND)
        target_link_libraries(${NAME} PRIVATE hip::device hip::host)
    endif()
endfunction()
|
||||
|
||||
# Helper for standalone GPU examples (instantiate kernel directly, no pre-generated header)
|
||||
# Helper for standalone GPU examples (instantiate kernel directly, no pre-generated header)
# add_standalone_gpu_example(NAME SOURCE): like add_gpu_example but without
# any kernel-header handling; the source instantiates its kernel itself.
function(add_standalone_gpu_example NAME SOURCE)
    add_executable(${NAME} ${SOURCE})

    target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher)

    target_include_directories(${NAME} PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/../../include  # CK root include
        ${CMAKE_CURRENT_SOURCE_DIR}/../include  # Dispatcher include
        ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels  # Generated kernels (optional)
    )

    target_compile_options(${NAME} PRIVATE
        -mllvm -enable-noalias-to-md-conversion=0
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
    )

    # HIP runtime libraries, only when a HIP toolchain was found.
    if(hip_FOUND)
        target_link_libraries(${NAME} PRIVATE hip::device hip::host)
    endif()
endfunction()
|
||||
|
||||
# Helper for declarative examples (configuration demo, still needs HIP compiler for CK headers)
|
||||
# Helper for declarative examples (configuration demo, still needs HIP compiler for CK headers)
# add_declarative_example(NAME SOURCE): configuration-only example; no
# generated-kernels include path and no --offload-compress.
function(add_declarative_example NAME SOURCE)
    add_executable(${NAME} ${SOURCE})

    target_include_directories(${NAME} PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/../../include
        ${CMAKE_CURRENT_SOURCE_DIR}/../include
    )

    target_compile_options(${NAME} PRIVATE
        -Wno-float-equal
        -Wno-unused-variable
        -Wno-undefined-func-template
        -mllvm -enable-noalias-to-md-conversion=0
    )

    target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher)

    # HIP runtime libraries, only when a HIP toolchain was found.
    if(hip_FOUND)
        target_link_libraries(${NAME} PRIVATE hip::device hip::host)
    endif()
endfunction()
|
||||
|
||||
# =============================================================================
|
||||
# GEMM Examples
|
||||
# =============================================================================
|
||||
|
||||
# Per-example kernel directories are created from DECL_KERNEL_SET declarations
|
||||
# Each example gets its own: build/<name>_kernels/
|
||||
# This prevents clashes during parallel compilation of multiple examples.
|
||||
|
||||
# Helper function to add example with declarative kernel support
|
||||
# Parses DECL_KERNEL_SET from source and generates ONLY the declared kernels
|
||||
# This enables minimal builds: only kernels needed by this example are generated
|
||||
#
|
||||
# Key features:
|
||||
# - Per-example kernel directories: build/<name>_kernels/ (no clashes)
|
||||
# - Automatic header inclusion: No hardcoded #include needed in source
|
||||
# - Minimal builds: Only declared kernels are generated
|
||||
# - Auto-regeneration: Kernels regenerated if directory missing
|
||||
# - Parallel compilation: Each kernel is a separate translation unit
|
||||
function(add_declarative_gpu_example NAME SOURCE)
    set(EXAMPLE_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE}")
    get_filename_component(EXAMPLE_STEM ${SOURCE} NAME_WE)

    # Each example owns a private kernel directory under the build tree so
    # parallel builds of several examples never clash.
    set(EXAMPLE_KERNEL_DIR "${CMAKE_BINARY_DIR}/${NAME}_kernels")
    set(EXAMPLE_HEADER     "${EXAMPLE_KERNEL_DIR}/${EXAMPLE_STEM}_kernels.hpp")
    set(EXAMPLE_LIB        "${EXAMPLE_KERNEL_DIR}/lib${NAME}_kernels.a")
    set(EXAMPLE_SENTINEL   "${EXAMPLE_KERNEL_DIR}/.generated")

    # Generate AND compile the declared kernels at make time (not configure
    # time): keeps cmake fast and gives per-kernel build progress.
    add_custom_command(
        OUTPUT ${EXAMPLE_SENTINEL} ${EXAMPLE_LIB}
        COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/example_kernel_builder.py
                ${EXAMPLE_SOURCE}
                --output-dir ${EXAMPLE_KERNEL_DIR}
                --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include"
                --gpu-target ${GPU_TARGET}
                --jobs ${NPROC}
                --target-name ${NAME}
        COMMAND ${CMAKE_COMMAND} -E touch ${EXAMPLE_SENTINEL}
        DEPENDS ${EXAMPLE_SOURCE}
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts
        COMMENT "[${NAME}] Generating and compiling kernels from DECL_KERNEL_SET..."
        VERBATIM
    )
    add_custom_target(generate_${NAME}_kernels DEPENDS ${EXAMPLE_SENTINEL})

    add_executable(${NAME} ${SOURCE})
    target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher)
    # Link the per-example static kernel library produced above.
    target_link_libraries(${NAME} PRIVATE ${EXAMPLE_LIB})

    target_include_directories(${NAME} PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/../../include
        ${CMAKE_CURRENT_SOURCE_DIR}/../include
        ${EXAMPLE_KERNEL_DIR}
        ${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers
    )

    # Force-include the generated registration header so the example source
    # needs no hardcoded #include of generated code.
    target_compile_options(${NAME} PRIVATE
        -include ${EXAMPLE_HEADER}
        -DGEMM_KERNEL_AVAILABLE=1
        -mllvm -enable-noalias-to-md-conversion=0
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
    )

    if(hip_FOUND)
        target_link_libraries(${NAME} PRIVATE hip::device hip::host)
    endif()

    # Only depends on generating THIS example's kernels.
    add_dependencies(${NAME} generate_${NAME}_kernels)
endfunction()
|
||||
|
||||
# GEMM C++ examples with declarative kernel support.  Each example's source
# carries a DECL_KERNEL_SET declaring exactly the kernels it needs.
add_declarative_gpu_example(gemm_01_basic gemm/cpp/01_basic_gemm.cpp)
add_declarative_gpu_example(gemm_02_multi_size gemm/cpp/02_multi_size.cpp)
add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_validation.cpp)
add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp)
add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp)
add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp)

# =============================================================================
# GEMM Python Library - Single Fallback Kernel
# =============================================================================

# One fallback kernel (fp16, rcr, compv4) instead of the full 6000+ kernel set.
set(GEMM_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/gemm_python_fallback")
set(GEMM_FALLBACK_KERNEL "${GEMM_FALLBACK_KERNEL_DIR}/gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16.hpp")

# Tile config JSON constraining the generator to exactly one configuration.
set(GEMM_FALLBACK_TILE_CONFIG "{\"tile_m\":[128],\"tile_n\":[128],\"tile_k\":[32],\"warp_m\":[2],\"warp_n\":[2],\"warp_k\":[1],\"warp_tile_m\":[32],\"warp_tile_n\":[32],\"warp_tile_k\":[16],\"pipeline\":[\"compv4\"],\"scheduler\":[\"intrawave\"],\"epilogue\":[\"cshuffle\"]}")

add_custom_command(
    OUTPUT ${GEMM_FALLBACK_KERNEL}
    COMMAND ${CMAKE_COMMAND} -E make_directory ${GEMM_FALLBACK_KERNEL_DIR}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py
            --datatype fp16 --layout rcr --variants standard
            --gpu-target ${GPU_TARGET}
            --output-dir ${GEMM_FALLBACK_KERNEL_DIR}
            --tile-config-json "${GEMM_FALLBACK_TILE_CONFIG}"
    COMMENT "Generating single fallback GEMM kernel for Python library"
    VERBATIM
)
add_custom_target(generate_gemm_fallback_kernel DEPENDS ${GEMM_FALLBACK_KERNEL})

# Shared library consumed by the Python ctypes bindings.
add_library(dispatcher_gemm_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/gemm_ctypes_lib.cpp)
target_link_libraries(dispatcher_gemm_lib PRIVATE ck_tile_dispatcher)
target_include_directories(dispatcher_gemm_lib PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include
    ${CMAKE_CURRENT_SOURCE_DIR}/../include
    ${GEMM_FALLBACK_KERNEL_DIR}
)
target_compile_options(dispatcher_gemm_lib PRIVATE
    -DCK_TILE_SINGLE_KERNEL_INCLUDE
    -include ${GEMM_FALLBACK_KERNEL}
    -DGFX_ARCH="${GPU_TARGET}"
    -mllvm -enable-noalias-to-md-conversion=0
    -Wno-undefined-func-template
    -Wno-float-equal
    --offload-compress
)
if(hip_FOUND)
    target_link_libraries(dispatcher_gemm_lib PRIVATE hip::device hip::host)
endif()
add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel)

message(STATUS "GEMM examples configured - kernels will be generated during 'make'")

# Convenience target to build all Python ctypes libraries.
add_custom_target(python_libs
    DEPENDS dispatcher_gemm_lib
    COMMENT "Building Python ctypes libraries (GEMM)"
)

# =============================================================================
# Per-Architecture Kernel Generation Targets
# =============================================================================

set(SUPPORTED_GPU_ARCHS gfx942 gfx90a gfx1100 gfx1030)

foreach(ARCH ${SUPPORTED_GPU_ARCHS})
    # GEMM kernels for this architecture.
    add_custom_target(generate_gemm_kernels_${ARCH}
        COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py
                --datatype fp16 --layout rcr --gpu-target ${ARCH}
                --output ${KERNEL_OUTPUT_DIR}
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen
        COMMENT "Generating GEMM kernels for ${ARCH}..."
        VERBATIM
    )
    # Umbrella alias (GEMM is currently the only kernel family).
    add_custom_target(generate_kernels_${ARCH}
        DEPENDS generate_gemm_kernels_${ARCH}
        COMMENT "Generating all kernels for ${ARCH}..."
    )
endforeach()

# =============================================================================
# Summary
# =============================================================================

message(STATUS "")
message(STATUS "=== Dispatcher Examples Configuration ===")
message(STATUS "")
message(STATUS "Kernels will be generated automatically during 'make'")
message(STATUS " Generated to: ${KERNEL_OUTPUT_DIR}")
message(STATUS "")
message(STATUS "Build targets:")
message(STATUS " make - Build all examples (generates kernels first)")
message(STATUS " make python_libs - Build Python ctypes libraries")
message(STATUS " make generate_all_kernels - Generate all kernels only")
message(STATUS " make regenerate_all_kernels - Force regenerate all kernels")
message(STATUS "")
message(STATUS "Per-architecture targets:")
message(STATUS " make generate_kernels_<arch> - Generate for specific arch")
message(STATUS " Supported archs: ${SUPPORTED_GPU_ARCHS}")
message(STATUS "")
|
||||
210
dispatcher/examples/README.md
Normal file
210
dispatcher/examples/README.md
Normal file
@@ -0,0 +1,210 @@
|
||||
# CK Tile Dispatcher Examples
|
||||
|
||||
Comprehensive examples for GEMM operations with GPU execution.
|
||||
|
||||
> **Note**: Convolution examples have been moved to `ck-2/conv_archive/` for reference.
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Step 1: Build
|
||||
|
||||
```bash
|
||||
cd /path/to/composable_kernel/dispatcher
|
||||
mkdir -p build && cd build
|
||||
|
||||
cmake .. \
|
||||
-DCMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-DCMAKE_BUILD_TYPE=Release \
|
||||
-DGPU_TARGETS="gfx942" \
|
||||
-DBUILD_DISPATCHER_EXAMPLES=ON
|
||||
|
||||
# Build everything (C++ examples + Python libraries)
|
||||
make -j$(nproc)
|
||||
|
||||
# Or build ONLY Python libraries (faster)
|
||||
make python_libs -j$(nproc)
|
||||
```
|
||||
|
||||
### Step 2: Run C++ Examples
|
||||
|
||||
```bash
|
||||
cd build/examples
|
||||
|
||||
# GEMM
|
||||
./gemm_01_basic
|
||||
./gemm_02_multi_size
|
||||
./gemm_03_benchmark_validation
|
||||
./gemm_04_heuristics
|
||||
./gemm_05_json_export
|
||||
./gemm_06_multi_registry
|
||||
```
|
||||
|
||||
### Step 3: Run Python Examples
|
||||
|
||||
```bash
|
||||
cd /path/to/composable_kernel/dispatcher
|
||||
|
||||
# GEMM
|
||||
python3 examples/gemm/python/01_basic_gemm.py
|
||||
python3 examples/gemm/python/04_validation.py
|
||||
python3 examples/gemm/python/07_stress_test.py
|
||||
python3 examples/gemm/python/08_heuristics.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
examples/
|
||||
├── gemm/
|
||||
│ ├── cpp/ # 6 C++ GEMM examples
|
||||
│ └── python/ # 11 Python GEMM examples
|
||||
│
|
||||
└── README.md
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## GEMM Examples
|
||||
|
||||
### C++ Examples
|
||||
|
||||
| # | Example | Description |
|
||||
|---|---------|-------------|
|
||||
| 01 | `gemm_01_basic` | Basic GEMM with declarative API, autofill, autocorrect |
|
||||
| 02 | `gemm_02_multi_size` | Wildcard expansion for multiple configurations |
|
||||
| 03 | `gemm_03_benchmark_validation` | Performance benchmarking with CPU/GPU validation |
|
||||
| 04 | `gemm_04_heuristics` | Heuristic-based kernel selection |
|
||||
| 05 | `gemm_05_json_export` | Registry JSON export for external tools |
|
||||
| 06 | `gemm_06_multi_registry` | Multiple registries with named kernel sets |
|
||||
|
||||
**Details:** [gemm/cpp/README.md](gemm/cpp/README.md)
|
||||
|
||||
---
|
||||
|
||||
### Python Examples
|
||||
|
||||
| # | Example | Description |
|
||||
|---|---------|-------------|
|
||||
| 01 | `01_basic_gemm.py` | Basic GEMM with multi-kernel support |
|
||||
| 02 | `02_batch_gemm.py` | Batched GEMM operations |
|
||||
| 03 | `03_benchmark.py` | Performance benchmarking |
|
||||
| 04 | `04_validation.py` | CPU reference validation |
|
||||
| 05 | `05_numpy_integration.py` | NumPy array integration |
|
||||
| 06 | `06_json_export.py` | Registry JSON export |
|
||||
| 07 | `07_stress_test.py` | Multi-kernel stress testing (48 configs) |
|
||||
| 08 | `08_heuristics.py` | Heuristic-based kernel selection (24 configs) |
|
||||
| 09 | `09_multi_registry.py` | Multiple registries |
|
||||
| 10 | `10_advanced_benchmark.py` | Advanced benchmark with full control |
|
||||
| 11 | `11_json_import.py` | Import kernels from JSON |
|
||||
|
||||
**Details:** [gemm/python/README.md](gemm/python/README.md)
|
||||
|
||||
---
|
||||
|
||||
## Key Features
|
||||
|
||||
### Declarative Kernel API
|
||||
|
||||
Both C++ and Python examples use a declarative approach:
|
||||
|
||||
**C++ (DECL_KERNEL_SET macro):**
|
||||
```cpp
|
||||
DECL_KERNEL_SET(my_kernels,
|
||||
.add(
|
||||
Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
|
||||
.pipeline("compv4").scheduler("intrawave"),
|
||||
"gfx942"
|
||||
)
|
||||
);
|
||||
```
|
||||
|
||||
**Python (KernelConfig):**
|
||||
```python
|
||||
config = KernelConfig(
|
||||
tile_m=256, tile_n=256, tile_k=32,
|
||||
wave_m=2, wave_n=2, wave_k=1,
|
||||
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
|
||||
pipeline="compv4", scheduler="intrawave"
|
||||
)
|
||||
```
|
||||
|
||||
### Autofill and Autocorrect
|
||||
|
||||
The build system automatically:
|
||||
- **Autofills** missing parameters with sensible defaults
|
||||
- **Autocorrects** invalid parameters based on architecture constraints
|
||||
- **Expands** wildcards (`*`, `-1`, `ANY_INT`) to all valid configurations
|
||||
|
||||
### Architecture Filtering
|
||||
|
||||
Kernel configurations are validated against GPU architecture constraints:
|
||||
- Tile divisibility requirements
|
||||
- Warp tile constraints
|
||||
- Pipeline compatibility
|
||||
|
||||
Invalid configurations are automatically pruned during code generation.
|
||||
|
||||
---
|
||||
|
||||
## Validation Examples
|
||||
|
||||
### C++ Validation
|
||||
|
||||
```bash
|
||||
./gemm_03_benchmark_validation --verify 1 # GEMM with CPU reference
|
||||
./gemm_03_benchmark_validation --verify 2 # GEMM with GPU reference
|
||||
```
|
||||
|
||||
### Python Validation
|
||||
|
||||
```bash
|
||||
python3 examples/gemm/python/04_validation.py
|
||||
python3 examples/gemm/python/07_stress_test.py # Multi-kernel validation
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Python: Library not found
|
||||
|
||||
```bash
|
||||
# Run from dispatcher directory
|
||||
cd /path/to/composable_kernel/dispatcher
|
||||
python3 examples/gemm/python/01_basic_gemm.py
|
||||
```
|
||||
|
||||
### C++: Executables not found
|
||||
|
||||
```bash
|
||||
# Build with examples enabled
|
||||
cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON
|
||||
make -j$(nproc)
|
||||
|
||||
# Run from build/examples
|
||||
cd build/examples
|
||||
./gemm_01_basic
|
||||
```
|
||||
|
||||
### GPU not detected
|
||||
|
||||
```bash
|
||||
rocminfo | grep "Name:"
|
||||
# Should show: gfx942, gfx90a, etc.
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Archived Examples
|
||||
|
||||
Convolution examples have been archived to `ck-2/conv_archive/dispatcher/`:
|
||||
- `examples/conv/cpp/` - 11 C++ convolution examples
|
||||
- `examples/conv/python/` - 14 Python convolution examples
|
||||
|
||||
See the archive for convolution functionality reference.
|
||||
243
dispatcher/examples/gemm/cpp/01_basic_gemm.cpp
Normal file
243
dispatcher/examples/gemm/cpp/01_basic_gemm.cpp
Normal file
@@ -0,0 +1,243 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* Example 01: Basic GEMM - Autofill, Autocorrect, and Full Declaration
|
||||
*
|
||||
* Demonstrates THREE declaration patterns:
|
||||
*
|
||||
* 1. AUTOFILL: Minimal declaration - missing params filled with defaults
|
||||
* .add(Signature().dtype("fp16").layout("rcr"),
|
||||
* Algorithm().tile(128,128,64).pipeline("compv3").scheduler("intrawave"),
|
||||
* "gfx942")
|
||||
* -> wave(2,2,1), warp(32,32,16), epilogue("cshuffle") added automatically
|
||||
*
|
||||
* 2. AUTOCORRECT: Invalid params corrected to valid values
|
||||
* .add(..., Algorithm().wave(1,1,1)...)
|
||||
* -> wave(1,1,1) is invalid for gfx942, corrected to wave(2,2,1)
|
||||
*
|
||||
* 3. FULL: All parameters explicitly specified
|
||||
* .add(..., Algorithm().tile().wave().warp().pipeline().scheduler().epilogue()...)
|
||||
*
|
||||
* Build: cd dispatcher/build && cmake .. && make gemm_01_basic
|
||||
*/
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_decl.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::backends;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
using Signature = decl::Signature;
|
||||
using Algorithm = decl::Algorithm;
|
||||
|
||||
// =============================================================================
|
||||
// THREE KERNEL DECLARATION PATTERNS
|
||||
// =============================================================================
|
||||
|
||||
DECL_KERNEL_SET(
|
||||
basic_gemm_kernels,
|
||||
// -------------------------------------------------------------------------
|
||||
// Pattern 1: AUTOFILL - Minimal declaration
|
||||
// Only specify: dtype, layout, tile, pipeline, scheduler
|
||||
// Auto-filled: wave(2,2,1), warp(32,32,16), epilogue("cshuffle"), pad(false,false,false)
|
||||
// -------------------------------------------------------------------------
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(128, 128, 64) // Required
|
||||
.pipeline("compv3") // Required
|
||||
.scheduler("intrawave"), // Required
|
||||
"gfx942")
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Pattern 2: AUTOCORRECT - Invalid wave config
|
||||
// wave(1,1,1) is invalid for gfx942 WMMA, corrected to wave(2,2,1)
|
||||
// -------------------------------------------------------------------------
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(128, 128, 32) // Different tile_k to make unique kernel
|
||||
.wave(1, 1, 1) // INVALID: autocorrected to (2,2,1)
|
||||
.warp(32, 32, 16) // Valid warp for 128x128 tile
|
||||
.pipeline("compv3")
|
||||
.scheduler("intrawave")
|
||||
.epilogue("cshuffle"),
|
||||
"gfx942")
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// Pattern 3: FULL - All parameters explicitly specified
|
||||
// No autofill or autocorrect needed
|
||||
// -------------------------------------------------------------------------
|
||||
.add(Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm()
|
||||
.tile(64, 64, 32) // Explicit tile
|
||||
.wave(2, 2, 1) // Explicit wave (valid)
|
||||
.warp(16, 16, 32) // Explicit warp tile
|
||||
.pipeline("compv3") // Explicit pipeline
|
||||
.scheduler("intrawave") // Explicit scheduler
|
||||
.epilogue("cshuffle") // Explicit epilogue
|
||||
.pad(false, false, false), // Explicit padding
|
||||
"gfx942"));
|
||||
|
||||
// =============================================================================
|
||||
// MAIN
|
||||
// =============================================================================
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 01: GEMM Autofill/Autocorrect/Full",
|
||||
"Three kernel declaration patterns");
|
||||
args.add_flag("--list", "List registered kernels");
|
||||
args.add_flag("--list-verbose", "List registered kernels with full details");
|
||||
args.add_option("--size", "1024", "Problem size MxNxK");
|
||||
args.add_option("--arch", "gfx942", "GPU architecture");
|
||||
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
print_header("Example 01: GEMM Declaration Patterns");
|
||||
|
||||
// =========================================================================
|
||||
// Show the Three Patterns
|
||||
// =========================================================================
|
||||
std::cout << "\nTHREE DECLARATION PATTERNS:\n";
|
||||
std::cout << "============================\n\n";
|
||||
|
||||
std::cout << "1. AUTOFILL (minimal declaration):\n";
|
||||
std::cout << " .add(Signature().dtype(\"fp16\").layout(\"rcr\"),\n";
|
||||
std::cout
|
||||
<< " Algorithm().tile(128,128,64).pipeline(\"compv3\").scheduler(\"intrawave\"),\n";
|
||||
std::cout << " \"gfx942\")\n";
|
||||
std::cout << " -> Auto-filled: wave(2,2,1), warp(32,32,16), epilogue(\"cshuffle\")\n\n";
|
||||
|
||||
std::cout << "2. AUTOCORRECT (invalid params fixed):\n";
|
||||
std::cout << " .add(..., Algorithm().wave(1,1,1)...)\n";
|
||||
std::cout << " -> wave(1,1,1) invalid for gfx942, corrected to wave(2,2,1)\n\n";
|
||||
|
||||
std::cout << "3. FULL (all params explicit):\n";
|
||||
std::cout << " .add(..., "
|
||||
"Algorithm().tile().wave().warp().pipeline().scheduler().epilogue().pad()...)\n";
|
||||
std::cout << " -> No changes needed\n\n";
|
||||
|
||||
std::string gfx_arch = args.get("--arch", "gfx942");
|
||||
|
||||
// =========================================================================
|
||||
// Step 1: Show Declared Kernel Sets
|
||||
// =========================================================================
|
||||
std::cout << "Step 1: Declared Kernel Sets\n";
|
||||
KernelSetRegistry::instance().print();
|
||||
|
||||
const auto& decl_set = KernelSetRegistry::instance().get("basic_gemm_kernels");
|
||||
std::cout << " 'basic_gemm_kernels': " << decl_set.size() << " declaration(s)\n";
|
||||
|
||||
// =========================================================================
|
||||
// Step 2: Create Registry and Register Kernels
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 2: Register Kernels\n";
|
||||
|
||||
Registry registry;
|
||||
// Use generic macro
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
|
||||
std::cout << " Registered " << registry.size() << " kernel(s)\n";
|
||||
|
||||
// List kernels if requested
|
||||
if(args.has("--list") || args.has("--list-verbose"))
|
||||
{
|
||||
std::cout << "\n";
|
||||
print_registered_kernels(registry, std::cout, args.has("--list-verbose"));
|
||||
return 0;
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Step 3: Create Dispatcher
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 3: Create Dispatcher\n";
|
||||
Dispatcher dispatcher(®istry);
|
||||
|
||||
// =========================================================================
|
||||
// Step 4: Setup Problem
|
||||
// =========================================================================
|
||||
int size = args.get_int("--size", 1024);
|
||||
const int M = size, N = size, K = size;
|
||||
|
||||
std::cout << "\nStep 4: Setup Problem (" << M << "x" << N << "x" << K << ")\n";
|
||||
|
||||
Problem problem(M, N, K);
|
||||
|
||||
using DataType = ck_tile::fp16_t;
|
||||
GpuBuffer<DataType> a_dev(M * K);
|
||||
GpuBuffer<DataType> b_dev(K * N);
|
||||
GpuBuffer<DataType> c_dev(M * N);
|
||||
|
||||
std::vector<DataType> a_host(M * K, DataType(1.0f));
|
||||
std::vector<DataType> b_host(K * N, DataType(1.0f));
|
||||
a_dev.copy_from_host(a_host.data());
|
||||
b_dev.copy_from_host(b_host.data());
|
||||
c_dev.zero();
|
||||
|
||||
// =========================================================================
|
||||
// Step 5: Select and Run
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 5: Select and Run\n";
|
||||
|
||||
auto selected = dispatcher.select_kernel(problem);
|
||||
if(!selected)
|
||||
{
|
||||
std::cerr << "ERROR: No kernel found!\n";
|
||||
return 1;
|
||||
}
|
||||
std::cout << " Selected: " << selected->get_name() << "\n";
|
||||
|
||||
float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
|
||||
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
|
||||
std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n";
|
||||
|
||||
// =========================================================================
|
||||
// Step 6: Verify
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 6: Verify\n";
|
||||
std::vector<DataType> c_host(M * N);
|
||||
c_dev.copy_to_host(c_host.data());
|
||||
|
||||
const float expected = static_cast<float>(K);
|
||||
int errors = 0;
|
||||
for(int i = 0; i < M * N; ++i)
|
||||
{
|
||||
if(std::abs(static_cast<float>(c_host[i]) - expected) > 0.01f * expected + 1.0f)
|
||||
++errors;
|
||||
}
|
||||
|
||||
bool passed = (errors == 0);
|
||||
std::cout << " Expected: " << expected << ", Errors: " << errors << "\n";
|
||||
std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
|
||||
|
||||
// =========================================================================
|
||||
// Summary
|
||||
// =========================================================================
|
||||
print_separator();
|
||||
std::cout << "DECLARATION PATTERNS SUMMARY:\n";
|
||||
print_separator();
|
||||
std::cout << R"(
|
||||
1. AUTOFILL: Specify only required params, system fills defaults
|
||||
- Useful for quick prototyping
|
||||
- Guarantees valid configuration
|
||||
|
||||
2. AUTOCORRECT: System validates and fixes invalid params
|
||||
- wave(1,1,1) -> wave(2,2,1) on gfx942
|
||||
- Invalid pipeline/scheduler combos fixed
|
||||
- Logs corrections for debugging
|
||||
|
||||
3. FULL: All params explicit - no changes made
|
||||
- Full control over configuration
|
||||
- Best for production/tuning
|
||||
)";
|
||||
print_separator();
|
||||
|
||||
return passed ? 0 : 1;
|
||||
}
|
||||
215
dispatcher/examples/gemm/cpp/02_multi_size.cpp
Normal file
215
dispatcher/examples/gemm/cpp/02_multi_size.cpp
Normal file
@@ -0,0 +1,215 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* Example 02: Multi-Size GEMM with Wildcard Expansion
|
||||
*
|
||||
* Demonstrates the WILDCARD feature where specifying wildcards causes
|
||||
* the build system to expand to ALL valid configurations for the architecture.
|
||||
*
|
||||
* WILDCARD SYNTAX:
|
||||
* - Integer params: ANY_INT or -1 (both are equivalent, ANY_INT is just a #define for -1)
|
||||
* - String params: "*" (for pipeline, scheduler)
|
||||
*
|
||||
* The kernel declaration:
|
||||
* .add(..., Algorithm().tile(64,64,64).wave(ANY_INT,ANY_INT,1).warp(-1,-1,-1)
|
||||
* .pipeline("*").scheduler("*"), ...)
|
||||
*
|
||||
* Expands to multiple kernels:
|
||||
* - wave: (1,4,1), (2,2,1), (4,1,1) -> 3 options
|
||||
* - warp: (16,16,32), (32,32,16) -> 2 options
|
||||
* - pipeline: "compv3" -> 1 option (compv4 requires special handling)
|
||||
* - scheduler: "intrawave" -> 1 option
|
||||
*
|
||||
* Raw expansion: 3 × 2 = 6 configs, but arch filter validates each:
|
||||
* - tile_m must be divisible by (warp_m × warp_tile_m)
|
||||
* - tile_n must be divisible by (warp_n × warp_tile_n)
|
||||
* - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16)
|
||||
* Result: 4 valid wildcard kernels + 1 explicit = 5 total
|
||||
*
|
||||
* Build: cd dispatcher/build && cmake .. && make gemm_02_multi_size
|
||||
* Usage: ./gemm_02_multi_size [--max-size N] [--help]
|
||||
*/
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_decl.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
using Signature = decl::Signature;
|
||||
using Algorithm = decl::Algorithm;
|
||||
|
||||
// =============================================================================
|
||||
// KERNEL SET: Demonstrates Wildcard Expansion
|
||||
// =============================================================================
|
||||
|
||||
DECL_KERNEL_SET(multi_size_kernels,
                // Kernel 1: fully explicit — no wildcard expansion happens.
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(64, 64, 32)
                         .wave(2, 2, 1)
                         .warp(16, 16, 32)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")

                // Kernel 2: WILDCARD — integer wildcards are ANY_INT (== -1),
                // string wildcards are "*"; the build expands them to every
                // architecture-valid combination.
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(64, 64, 64)
                         .wave(ANY_INT, ANY_INT, 1) // -> (1,4,1), (2,2,1), (4,1,1)
                         .warp(-1, -1, -1)          // -1 == ANY_INT -> (16,16,32), (32,32,16)
                         .pipeline("*")             // -> every valid pipeline
                         .scheduler("*")            // -> every valid scheduler
                         .epilogue("cshuffle"),
                     "gfx942"));
// Raw expansion 3x2 = 6 configs; the arch filter drops 2 invalid wave/warp
// pairings, leaving 4 wildcard kernels (+1 explicit = 5 total).
|
||||
|
||||
// =============================================================================
|
||||
// MAIN
|
||||
// =============================================================================
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 02: Multi-Size GEMM with Wildcards",
|
||||
"Demonstrates wildcard expansion for kernel generation");
|
||||
args.add_option("--max-size", "4096", "Maximum problem size to test");
|
||||
args.add_option("--arch", "gfx942", "GPU architecture");
|
||||
args.add_flag("--list", "List all registered kernels");
|
||||
args.add_flag("--list-verbose", "List kernels with full configuration details");
|
||||
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
int max_size = args.get_int("--max-size", 4096);
|
||||
std::string gfx_arch = args.get("--arch", "gfx942");
|
||||
|
||||
print_header("Example 02: Multi-Size GEMM with Wildcards");
|
||||
|
||||
// =========================================================================
|
||||
// Show Wildcard Expansion Concept
|
||||
// =========================================================================
|
||||
std::cout << "\nWILDCARD EXPANSION:\n";
|
||||
std::cout << "===================\n";
|
||||
std::cout << R"(
|
||||
Wildcard syntax:
|
||||
ANY_INT or -1 -> expands integer params to all valid values
|
||||
"*" -> expands string params (pipeline/scheduler) to valid values
|
||||
|
||||
Declaration with wildcards:
|
||||
.tile(64, 64, 64) -> fixed tile size (no wildcard)
|
||||
.wave(ANY_INT, ANY_INT, 1) -> expands to (1,4,1), (2,2,1), (4,1,1) = 3
|
||||
.warp(-1, -1, -1) -> expands to (16,16,32), (32,32,16) = 2
|
||||
.pipeline("*") -> expands to valid pipelines = 1
|
||||
.scheduler("*") -> expands to valid schedulers = 1
|
||||
|
||||
Expanded: 3 × 2 = 6 configs, but arch filter validates each:
|
||||
- wave×warp must divide tile: (4,1,1)×(32,32,16) invalid for 64x64
|
||||
- Result: 4 valid kernels from wildcard + 1 explicit = 5 total
|
||||
)";
|
||||
|
||||
// =========================================================================
|
||||
// Setup Registry and Dispatcher
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 1: Register Kernels\n";
|
||||
std::cout << "------------------------\n";
|
||||
|
||||
Registry registry;
|
||||
registry.set_name("multi_size_registry");
|
||||
|
||||
// Register kernels from generated header (includes expanded wildcards)
|
||||
// Use generic macro - no need to hardcode example name
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
std::cout << " Registered " << registry.size() << " kernel(s) from wildcard expansion\n";
|
||||
|
||||
if(args.has("--list") || args.has("--list-verbose"))
|
||||
{
|
||||
std::cout << "\n";
|
||||
print_registered_kernels(registry, std::cout, args.has("--list-verbose"));
|
||||
return 0;
|
||||
}
|
||||
|
||||
Dispatcher dispatcher(®istry);
|
||||
std::cout << " Max size: " << max_size << "\n";
|
||||
|
||||
// =========================================================================
|
||||
// Run Multiple Problem Sizes
|
||||
// =========================================================================
|
||||
std::cout << "\nStep 2: Run Multiple Sizes\n";
|
||||
print_separator();
|
||||
std::cout << std::setw(12) << "M" << std::setw(12) << "N" << std::setw(12) << "K"
|
||||
<< std::setw(12) << "Time(ms)" << std::setw(12) << "TFLOPS" << "\n";
|
||||
print_separator();
|
||||
|
||||
std::vector<std::tuple<int, int, int>> all_sizes = {
|
||||
{256, 256, 256},
|
||||
{512, 512, 512},
|
||||
{1024, 1024, 1024},
|
||||
{2048, 2048, 2048},
|
||||
{4096, 4096, 4096},
|
||||
};
|
||||
|
||||
std::vector<std::tuple<int, int, int>> sizes;
|
||||
for(const auto& [M, N, K] : all_sizes)
|
||||
{
|
||||
if(std::max({M, N, K}) <= max_size)
|
||||
sizes.push_back({M, N, K});
|
||||
}
|
||||
|
||||
using DataType = ck_tile::fp16_t;
|
||||
bool all_passed = true;
|
||||
|
||||
for(const auto& [M, N, K] : sizes)
|
||||
{
|
||||
Problem problem(M, N, K);
|
||||
|
||||
GpuBuffer<DataType> a_dev(M * K);
|
||||
GpuBuffer<DataType> b_dev(K * N);
|
||||
GpuBuffer<DataType> c_dev(M * N);
|
||||
|
||||
std::vector<DataType> a_host(M * K, DataType(1.0f));
|
||||
std::vector<DataType> b_host(K * N, DataType(1.0f));
|
||||
a_dev.copy_from_host(a_host.data());
|
||||
b_dev.copy_from_host(b_host.data());
|
||||
c_dev.zero();
|
||||
|
||||
float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
|
||||
double tflops = calculate_tflops(M, N, K, time_ms);
|
||||
|
||||
std::cout << std::setw(12) << M << std::setw(12) << N << std::setw(12) << K << std::setw(12)
|
||||
<< std::fixed << std::setprecision(4) << time_ms << std::setw(12)
|
||||
<< std::setprecision(2) << tflops << "\n";
|
||||
|
||||
// Verify
|
||||
std::vector<DataType> c_host(M * N);
|
||||
c_dev.copy_to_host(c_host.data());
|
||||
float expected = static_cast<float>(K);
|
||||
int errors = 0;
|
||||
for(int i = 0; i < M * N; ++i)
|
||||
{
|
||||
if(std::abs(static_cast<float>(c_host[i]) - expected) > 0.01f * expected + 1.0f)
|
||||
++errors;
|
||||
}
|
||||
if(errors > 0)
|
||||
all_passed = false;
|
||||
}
|
||||
|
||||
print_separator();
|
||||
std::cout << "Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n";
|
||||
print_separator();
|
||||
|
||||
return all_passed ? 0 : 1;
|
||||
}
|
||||
344
dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp
Normal file
344
dispatcher/examples/gemm/cpp/03_benchmark_validation.cpp
Normal file
@@ -0,0 +1,344 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* Example 03: GEMM Benchmark & Validation
|
||||
*
|
||||
* Combined example demonstrating:
|
||||
* 1. Benchmarking with statistics (warmup, iterations, min/max/mean/median)
|
||||
* 2. Validation against CK Tile reference (CPU or GPU)
|
||||
*
|
||||
* Build: cd dispatcher/build && cmake .. && make gemm_03_benchmark_validation
|
||||
* Usage: ./gemm_03_benchmark_validation [--size N] [--verify MODE] [--benchmark]
|
||||
*
|
||||
* Options:
|
||||
* --size N Problem size MxNxK (default: 512)
|
||||
* --verify MODE 0=none, 1=CPU ref, 2=GPU ref (default: 1)
|
||||
* --benchmark Run full benchmark with statistics
|
||||
* --warmup N Warmup iterations (default: 5)
|
||||
* --iterations N Benchmark iterations (default: 20)
|
||||
*/
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <cmath>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_decl.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
using namespace ck_tile::literals;
|
||||
using Signature = decl::Signature;
|
||||
using Algorithm = decl::Algorithm;
|
||||
|
||||
// =============================================================================
|
||||
// KERNEL SET: High-performance kernels for benchmarking/validation
|
||||
// =============================================================================
|
||||
|
||||
// Declare the kernel set used by this example: a single fp16 kernel for the
// RCR layout (A row-major, B column-major, C row-major), restricted to the
// "gfx942" architecture. The Algorithm fields below are consumed by the
// dispatcher's code generator; their exact semantics live in the Algorithm
// builder (kernel_decl.hpp), not here.
DECL_KERNEL_SET(benchmark_validation_kernels,
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 128, 32)     // tile dimensions (M, N, K)
                         .wave(2, 2, 1)          // wave arrangement per tile
                         .warp(32, 32, 16)       // per-warp GEMM shape
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942"));
|
||||
|
||||
// =============================================================================
|
||||
// Helper: Layout detection
|
||||
// =============================================================================
|
||||
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// MAIN
|
||||
// =============================================================================
|
||||
|
||||
// Entry point for Example 03.
//
// Flow: parse CLI -> build registry/dispatcher -> allocate and fill host/device
// tensors -> optionally compute a CPU or GPU reference -> run the dispatched
// kernel (once, or warmup+measure when --benchmark) -> print timing statistics
// -> optionally validate against the reference.
// Returns 0 on PASS (or when validation is disabled), 1 on FAIL.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 03: GEMM Benchmark & Validation",
                     "Benchmark and/or validate GEMM output against reference");
    args.add_option("--size", "512", "Problem size MxNxK");
    args.add_option("--verify", "1", "Verification: 0=none, 1=CPU ref, 2=GPU ref");
    args.add_flag("--benchmark", "Run benchmark with statistics");
    args.add_option("--warmup", "5", "Warmup iterations");
    args.add_option("--iterations", "20", "Benchmark iterations");
    args.add_option("--rtol", "0.01", "Relative tolerance");
    args.add_option("--atol", "0.01", "Absolute tolerance");
    args.add_option("--arch", "gfx942", "GPU architecture");

    // parse() returning false covers --help as well; exit code 0 either way.
    if(!args.parse(argc, argv))
        return 0;

    // Square problem: --size sets M, and N = K = M.
    int M                = args.get_int("--size", 512);
    int N                = M;
    int K                = M;
    int verify           = args.get_int("--verify", 1);
    bool do_benchmark    = args.has("--benchmark");
    int warmup           = args.get_int("--warmup", 5);
    int iterations       = args.get_int("--iterations", 20);
    float rtol           = args.get_float("--rtol", 0.01f);
    float atol           = args.get_float("--atol", 0.01f);
    std::string gfx_arch = args.get("--arch", "gfx942");

    print_header("Example 03: GEMM Benchmark & Validation");

    std::cout << "\nConfiguration:\n";
    std::cout << " Problem: " << M << " x " << N << " x " << K << "\n";
    std::cout << " Layout: RCR (A=row, B=col, C=row)\n";
    std::cout << " Verify: " << verify;
    if(verify == 0)
        std::cout << " (disabled)";
    else if(verify == 1)
        std::cout << " (CPU reference)";
    else if(verify == 2)
        std::cout << " (GPU reference)";
    std::cout << "\n";
    std::cout << " Benchmark: " << (do_benchmark ? "yes" : "no") << "\n";
    if(do_benchmark)
    {
        std::cout << " Warmup: " << warmup << " iterations\n";
        std::cout << " Measure: " << iterations << " iterations\n";
    }

    // =========================================================================
    // Setup Registry and Dispatcher
    // =========================================================================
    // REGISTER_GENERATED_KERNELS pulls in every kernel declared via
    // DECL_KERNEL_SET that was compiled into this binary for gfx_arch.
    Registry registry;
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    Dispatcher dispatcher(&registry);

    std::cout << " Kernels: " << registry.size() << " registered\n";
    print_registered_kernels(registry);

    // =========================================================================
    // Initialize data with proper tensor descriptors
    // =========================================================================
    // Layouts match the "rcr" signature of the declared kernel set above.
    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;

    using ADataType   = ck_tile::fp16_t;
    using BDataType   = ck_tile::fp16_t;
    using CDataType   = ck_tile::fp16_t;
    using AccDataType = float; // accumulate in fp32

    // 0_uz asks get_default_stride to derive the packed stride from the layout.
    auto stride_a = ck_tile::get_default_stride(M, K, 0_uz, is_row_major(ALayout{}));
    auto stride_b = ck_tile::get_default_stride(K, N, 0_uz, is_row_major(BLayout{}));
    auto stride_c = ck_tile::get_default_stride(M, N, 0_uz, is_row_major(CLayout{}));

    ck_tile::HostTensor<ADataType> a_m_k(
        ck_tile::host_tensor_descriptor(M, K, stride_a, is_row_major(ALayout{})));
    ck_tile::HostTensor<BDataType> b_k_n(
        ck_tile::host_tensor_descriptor(K, N, stride_b, is_row_major(BLayout{})));
    ck_tile::HostTensor<CDataType> c_m_n_dev(
        ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{})));
    ck_tile::HostTensor<CDataType> c_m_n_ref(
        ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{})));

    // Small symmetric range keeps fp16 accumulation well-conditioned.
    ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
    ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);

    std::cout << "\nData:\n";
    std::cout << " A: " << M << " x " << K << " (fp16, row-major)\n";
    std::cout << " B: " << K << " x " << N << " (fp16, col-major)\n";
    std::cout << " C: " << M << " x " << N << " (fp16, row-major)\n";

    // GPU memory
    ck_tile::DeviceMem a_dev(a_m_k.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_dev(b_k_n.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_dev(c_m_n_dev.get_element_space_size_in_bytes());

    a_dev.ToDevice(a_m_k.data());
    b_dev.ToDevice(b_k_n.data());

    // =========================================================================
    // Compute Reference (if needed)
    // =========================================================================
    if(verify > 0)
    {
        std::cout << "\nComputing reference...\n";
        c_m_n_ref.SetZero();

        if(verify == 1)
        {
            // Host-side reference: slower but independent of the GPU stack.
            std::cout << " Using CPU reference (ck_tile::reference_gemm)\n";
            ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
                a_m_k, b_k_n, c_m_n_ref);
        }
        else if(verify == 2)
        {
            // Device-side reference: runs a naive GPU GEMM into a scratch
            // buffer, then copies the result back into c_m_n_ref.
            std::cout << " Using GPU reference (ck_tile::reference_gemm_gpu)\n";
            ck_tile::DeviceMem c_ref_dev(c_m_n_ref.get_element_space_size_in_bytes());
            c_ref_dev.SetZero();

            ck_tile::reference_gemm_gpu<ADataType,
                                        BDataType,
                                        AccDataType,
                                        CDataType,
                                        ALayout,
                                        BLayout,
                                        CLayout>(
                static_cast<ADataType*>(a_dev.GetDeviceBuffer()),
                static_cast<BDataType*>(b_dev.GetDeviceBuffer()),
                static_cast<CDataType*>(c_ref_dev.GetDeviceBuffer()),
                M,
                N,
                K,
                stride_a,
                stride_b,
                stride_c);

            // Ensure the reference kernel finished before reading back.
            (void)hipDeviceSynchronize();
            c_ref_dev.FromDevice(c_m_n_ref.data());
        }
        std::cout << " Reference complete.\n";
    }

    // =========================================================================
    // Run Kernel
    // =========================================================================
    Problem problem(M, N, K);
    auto selected = dispatcher.select_kernel(problem);

    std::cout << "\nRunning kernel:\n";
    if(selected)
        std::cout << " Selected: " << selected->get_name() << "\n";

    c_dev.SetZero();
    float time_ms = 0.0f;
    std::vector<float> times;

    if(do_benchmark)
    {
        // Warmup: untimed runs so later measurements exclude one-time costs
        // (module load, cache warming). C is re-zeroed before every run.
        std::cout << " Warming up (" << warmup << " iterations)...\n";
        for(int i = 0; i < warmup; ++i)
        {
            c_dev.SetZero();
            (void)dispatcher.run(static_cast<ADataType*>(a_dev.GetDeviceBuffer()),
                                 static_cast<BDataType*>(b_dev.GetDeviceBuffer()),
                                 static_cast<CDataType*>(c_dev.GetDeviceBuffer()),
                                 problem,
                                 nullptr);
        }

        // Benchmark: collect per-iteration times; headline number is the min.
        std::cout << " Benchmarking (" << iterations << " iterations)...\n";
        times.reserve(iterations);
        for(int i = 0; i < iterations; ++i)
        {
            c_dev.SetZero();
            float t = dispatcher.run(static_cast<ADataType*>(a_dev.GetDeviceBuffer()),
                                     static_cast<BDataType*>(b_dev.GetDeviceBuffer()),
                                     static_cast<CDataType*>(c_dev.GetDeviceBuffer()),
                                     problem,
                                     nullptr);
            times.push_back(t);
        }
        time_ms = *std::min_element(times.begin(), times.end());
    }
    else
    {
        // Single run
        time_ms = dispatcher.run(static_cast<ADataType*>(a_dev.GetDeviceBuffer()),
                                 static_cast<BDataType*>(b_dev.GetDeviceBuffer()),
                                 static_cast<CDataType*>(c_dev.GetDeviceBuffer()),
                                 problem,
                                 nullptr);
    }

    c_dev.FromDevice(c_m_n_dev.data());

    // =========================================================================
    // Results
    // =========================================================================
    // GEMM does 2*M*N*K flops; time_ms * 1e9 converts ms -> TFLOPS divisor.
    double flops  = 2.0 * M * N * K;
    double tflops = flops / (time_ms * 1e9);

    print_separator();
    std::cout << "Performance:\n";
    print_separator();

    if(do_benchmark && !times.empty())
    {
        // NOTE: times[size/2] is the upper median for an even sample count.
        std::sort(times.begin(), times.end());
        float min_t    = times.front();
        float max_t    = times.back();
        float median_t = times[times.size() / 2];
        float mean_t   = std::accumulate(times.begin(), times.end(), 0.0f) / times.size();

        std::cout << std::fixed << std::setprecision(4);
        std::cout << " Min: " << min_t << " ms (" << std::setprecision(2)
                  << (flops / (min_t * 1e9)) << " TFLOPS)\n";
        std::cout << std::setprecision(4);
        std::cout << " Max: " << max_t << " ms\n";
        std::cout << " Mean: " << mean_t << " ms (" << std::setprecision(2)
                  << (flops / (mean_t * 1e9)) << " TFLOPS)\n";
        std::cout << std::setprecision(4);
        std::cout << " Median: " << median_t << " ms (" << std::setprecision(2)
                  << (flops / (median_t * 1e9)) << " TFLOPS)\n";
    }
    else
    {
        std::cout << std::fixed << std::setprecision(4);
        std::cout << " Time: " << time_ms << " ms\n";
        std::cout << std::setprecision(2);
        std::cout << " TFLOPS: " << tflops << "\n";
    }

    // =========================================================================
    // Validation
    // =========================================================================
    bool pass = true;

    if(verify > 0)
    {
        print_separator();
        std::cout << "Validation:\n";
        print_separator();
        std::cout << " Tolerance: rtol=" << rtol << ", atol=" << atol << "\n";

        pass = ck_tile::check_err(c_m_n_dev, c_m_n_ref, "Validation Error!", rtol, atol);

        // Diagnostic max-error scan over the raw element buffers (in addition
        // to check_err's pass/fail verdict above).
        float max_abs_diff = 0.0f;
        float max_rel_diff = 0.0f;
        for(size_t i = 0; i < c_m_n_dev.get_element_space_size(); ++i)
        {
            float dev_val  = static_cast<float>(c_m_n_dev.mData[i]);
            float ref_val  = static_cast<float>(c_m_n_ref.mData[i]);
            float abs_diff = std::abs(dev_val - ref_val);
            float rel_diff = (ref_val != 0.0f) ? abs_diff / std::abs(ref_val) : abs_diff;
            max_abs_diff   = std::max(max_abs_diff, abs_diff);
            max_rel_diff   = std::max(max_rel_diff, rel_diff);
        }

        std::cout << " Max abs diff: " << max_abs_diff << "\n";
        std::cout << " Max rel diff: " << max_rel_diff << "\n";
    }

    // =========================================================================
    // Summary
    // =========================================================================
    print_separator();
    std::cout << "Result: " << (pass ? "PASS" : "FAIL") << "\n";
    print_separator();

    return pass ? 0 : 1;
}
|
||||
168
dispatcher/examples/gemm/cpp/04_heuristics.cpp
Normal file
168
dispatcher/examples/gemm/cpp/04_heuristics.cpp
Normal file
@@ -0,0 +1,168 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* Example 04: Custom Heuristics
|
||||
*
|
||||
* Demonstrates custom kernel selection heuristics for different workloads.
|
||||
*
|
||||
* Build: cd dispatcher/build && cmake .. && make gemm_04_heuristics
|
||||
*/
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_decl.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
using Signature = decl::Signature;
|
||||
using Algorithm = decl::Algorithm;
|
||||
|
||||
// =============================================================================
|
||||
// KERNEL SET: Multiple tile sizes for heuristic-based selection
|
||||
// =============================================================================
|
||||
|
||||
// Two fp16/RCR kernels that differ only in tile size, giving the custom
// heuristic below a real choice: a small tile for latency-sensitive problems
// and a larger tile for throughput.
DECL_KERNEL_SET(heuristics_kernels,
                // Small tile - low latency
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(64, 64, 32)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                // Medium tile - balanced
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 128, 64)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942"));
|
||||
|
||||
// =============================================================================
|
||||
// Custom Heuristic
|
||||
// =============================================================================
|
||||
|
||||
std::vector<std::string> size_based_heuristic(const Problem& problem)
|
||||
{
|
||||
std::vector<std::string> ranked_kernels;
|
||||
int64_t total_elements = problem.M * problem.N;
|
||||
|
||||
if(total_elements < 100000)
|
||||
{
|
||||
ranked_kernels = {"gemm_64x64", "gemm_128x128"};
|
||||
}
|
||||
else
|
||||
{
|
||||
ranked_kernels = {"gemm_128x128", "gemm_64x64"};
|
||||
}
|
||||
return ranked_kernels;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// MAIN
|
||||
// =============================================================================
|
||||
|
||||
// Entry point for Example 04.
//
// Installs size_based_heuristic as the dispatcher's selection strategy, then
// runs three problem sizes (all-ones inputs, so every C element should equal
// K) and verifies each result within 1% relative + 1.0 absolute tolerance.
// Returns 0 if all sizes pass, 1 otherwise.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 04: Custom Heuristics",
                     "Demonstrates custom kernel selection heuristics");
    args.add_option("--arch", "gfx942", "GPU architecture");

    if(!args.parse(argc, argv))
        return 0;

    print_header("Example 04: Custom Heuristics");

    std::string gfx_arch = args.get("--arch", "gfx942");

    // =========================================================================
    // Setup Registry and Dispatcher
    // =========================================================================
    Registry registry;
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);

    // Switch from the default strategy to the user-supplied heuristic, which
    // returns kernel names in preference order per problem.
    Dispatcher dispatcher(&registry);
    dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic);
    dispatcher.set_heuristic(size_based_heuristic);

    std::cout << "\nSetup:\n";
    std::cout << " Registry: " << registry.size() << " kernel(s)\n";
    std::cout << " Strategy: Heuristic (size-based)\n";

    // =========================================================================
    // Test Different Problem Sizes
    // =========================================================================
    std::cout << "\nTesting heuristic selection:\n";
    print_separator();

    using DataType = ck_tile::fp16_t;

    // Spans both sides of the heuristic's 100k-element threshold.
    std::vector<std::tuple<int, int, int>> sizes = {
        {128, 128, 64},
        {512, 512, 256},
        {2048, 2048, 1024},
    };

    bool all_passed = true;

    for(const auto& [M, N, K] : sizes)
    {
        Problem problem(M, N, K);
        auto selected = dispatcher.select_kernel(problem);

        std::cout << "Problem " << M << "x" << N << "x" << K << ":\n";
        if(selected)
        {
            std::cout << " Selected: " << selected->get_name() << "\n";
        }

        GpuBuffer<DataType> a_dev(M * K);
        GpuBuffer<DataType> b_dev(K * N);
        GpuBuffer<DataType> c_dev(M * N);

        // All-ones inputs make the expected result exactly K per element.
        std::vector<DataType> a_host(M * K, DataType(1.0f));
        std::vector<DataType> b_host(K * N, DataType(1.0f));
        a_dev.copy_from_host(a_host.data());
        b_dev.copy_from_host(b_host.data());
        c_dev.zero();

        float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
        double tflops = calculate_tflops(M, N, K, time_ms);

        std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
        std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";

        // Verify: every element should be K within 1% relative + 1.0 absolute.
        std::vector<DataType> c_host(M * N);
        c_dev.copy_to_host(c_host.data());
        float expected = static_cast<float>(K);
        int errors     = 0;
        for(int i = 0; i < M * N; ++i)
        {
            float actual = static_cast<float>(c_host[i]);
            if(std::abs(actual - expected) > 0.01f * expected + 1.0f)
                ++errors;
        }
        bool pass = (errors == 0);
        std::cout << " Verify: " << (pass ? "PASS" : "FAIL") << "\n";
        if(!pass)
            all_passed = false;
        print_separator();
    }

    std::cout << "Overall: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n";

    return all_passed ? 0 : 1;
}
|
||||
127
dispatcher/examples/gemm/cpp/05_json_export.cpp
Normal file
127
dispatcher/examples/gemm/cpp/05_json_export.cpp
Normal file
@@ -0,0 +1,127 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* Example 05: JSON Export
|
||||
*
|
||||
* Demonstrates exporting registry information to JSON format.
|
||||
*
|
||||
* Build: cd dispatcher/build && cmake .. && make gemm_05_json_export
|
||||
*/
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_decl.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
using Signature = decl::Signature;
|
||||
using Algorithm = decl::Algorithm;
|
||||
|
||||
// =============================================================================
|
||||
// KERNEL SET: Multiple kernels for JSON export demo
|
||||
// =============================================================================
|
||||
|
||||
// Two fp16/RCR kernels (64x64x32 and 128x128x64 tiles) so the exported JSON
// contains more than one entry to inspect.
DECL_KERNEL_SET(json_export_kernels,
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(64, 64, 32)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 128, 64)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942"));
|
||||
|
||||
// =============================================================================
|
||||
// MAIN
|
||||
// =============================================================================
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExampleArgs args("Example 05: JSON Export", "Export registry information to JSON format");
|
||||
args.add_option("--output", "registry.json", "Output JSON file path");
|
||||
args.add_option("--arch", "gfx942", "GPU architecture");
|
||||
args.add_flag("--list", "List all kernel sets");
|
||||
|
||||
if(!args.parse(argc, argv))
|
||||
return 0;
|
||||
|
||||
print_header("Example 05: JSON Export");
|
||||
|
||||
std::string gfx_arch = args.get("--arch", "gfx942");
|
||||
|
||||
if(args.has("--list"))
|
||||
{
|
||||
std::cout << "\nDeclared Kernel Sets:\n";
|
||||
KernelSetRegistry::instance().print();
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string output_file = args.get("--output", "registry.json");
|
||||
|
||||
// =========================================================================
|
||||
// Setup Registry
|
||||
// =========================================================================
|
||||
std::cout << "\nSetting up registry...\n";
|
||||
Registry registry;
|
||||
registry.set_name("json_export_registry");
|
||||
|
||||
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
|
||||
|
||||
std::cout << " Registry: " << registry.get_name() << "\n";
|
||||
std::cout << " Kernels: " << registry.size() << "\n";
|
||||
|
||||
// =========================================================================
|
||||
// Export to JSON
|
||||
// =========================================================================
|
||||
std::cout << "\nExporting to JSON...\n";
|
||||
|
||||
std::string json = registry.export_json(true);
|
||||
|
||||
std::cout << "\nJSON Preview (first 500 chars):\n";
|
||||
print_separator();
|
||||
std::cout << json.substr(0, std::min(size_t(500), json.size()));
|
||||
if(json.size() > 500)
|
||||
std::cout << "\n...";
|
||||
std::cout << "\n";
|
||||
print_separator();
|
||||
|
||||
// Write to file
|
||||
std::ofstream file(output_file);
|
||||
if(file.is_open())
|
||||
{
|
||||
file << json;
|
||||
file.close();
|
||||
std::cout << "\nExported to: " << output_file << "\n";
|
||||
std::cout << "File size: " << json.size() << " bytes\n";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Failed to write to: " << output_file << "\n";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Also show kernel set declarations
|
||||
// =========================================================================
|
||||
std::cout << "\nKernel Set Declarations:\n";
|
||||
print_separator();
|
||||
KernelSetRegistry::instance().print();
|
||||
print_separator();
|
||||
|
||||
return 0;
|
||||
}
|
||||
294
dispatcher/examples/gemm/cpp/06_multi_registry.cpp
Normal file
294
dispatcher/examples/gemm/cpp/06_multi_registry.cpp
Normal file
@@ -0,0 +1,294 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/**
|
||||
* Example 06: Multiple Registries and Multiple Kernel Sets
|
||||
*
|
||||
* Demonstrates:
|
||||
* - Multiple DECL_KERNEL_SET declarations (each with multiple kernels)
|
||||
* - Separate Registry instances for different workload types
|
||||
* - Independent Dispatchers that select from their respective registries
|
||||
*
|
||||
* Registration patterns:
|
||||
* - REGISTER_GENERATED_KERNELS(registry, arch) -> all kernels to one registry
|
||||
* - REGISTER_KERNEL_SET("set_name", registry, arch) -> specific set by name
|
||||
* - generated::get_kernel_set_names() -> list available set names
|
||||
*
|
||||
* Build: cd dispatcher/build && cmake .. && make gemm_06_multi_registry
|
||||
* Usage: ./gemm_06_multi_registry [--list] [--help]
|
||||
*/
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/dispatcher.hpp"
|
||||
#include "ck_tile/dispatcher/kernel_decl.hpp"
|
||||
#include "ck_tile/dispatcher/example_args.hpp"
|
||||
|
||||
using namespace ck_tile::dispatcher;
|
||||
using namespace ck_tile::dispatcher::utils;
|
||||
using Signature = decl::Signature;
|
||||
using Algorithm = decl::Algorithm;
|
||||
|
||||
// =============================================================================
|
||||
// KERNEL SETS: Multiple sets with multiple kernels each
|
||||
// =============================================================================
|
||||
|
||||
// Compute-bound kernel set: Large tiles for high arithmetic intensity
|
||||
// Max tile with 32x32 warp is 128x128 (16 warps = 1024 threads)
|
||||
// Three independent kernel sets, registered separately into different
// registries by main() below via REGISTER_KERNEL_SET("<name>", ...).
// All are fp16/RCR on gfx942; they differ only in tile shape.
DECL_KERNEL_SET(compute_bound_set,
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 128, 64) // Large tile, max for 32x32 warp
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 128, 32) // Same tile, different K for variety
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942"));

// Memory-bound kernel set: Smaller tiles for better cache efficiency
DECL_KERNEL_SET(memory_bound_set,
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(64, 64, 32)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942")
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(128, 64, 32)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942"));

// Latency-optimized: Minimal overhead tiles (single kernel)
DECL_KERNEL_SET(latency_set,
                .add(Signature().dtype("fp16").layout("rcr"),
                     Algorithm()
                         .tile(64, 64, 64)
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave")
                         .epilogue("cshuffle"),
                     "gfx942"));
|
||||
|
||||
// =============================================================================
|
||||
// MAIN
|
||||
// =============================================================================
|
||||
|
||||
// Example 06 entry point: demonstrates declaring several named kernel sets,
// registering them into separate Registry objects, merging two sets into one
// combined registry, and dispatching workloads through per-registry
// Dispatchers. Returns 0 when every workload verifies, 1 otherwise.
int main(int argc, char* argv[])
{
    ExampleArgs args("Example 06: Multiple Registries",
                     "Separate registries for different workload types");
    args.add_flag("--list", "List all declared kernel sets");
    args.add_option("--arch", "gfx942", "GPU architecture");

    // parse() returning false means --help (or similar) was handled; not an error.
    if(!args.parse(argc, argv))
        return 0;

    print_header("Example 06: Multiple Registries & Kernel Sets");

    std::string gfx_arch = args.get("--arch", "gfx942");

    // =========================================================================
    // Step 1: Show declared kernel sets (from DECL_KERNEL_SET macros)
    // =========================================================================
    std::cout << "\nStep 1: Declared Kernel Sets\n";
    std::cout << "-----------------------------\n";
    KernelSetRegistry::instance().print();

    if(args.has("--list"))
    {
        // Print detailed info for every declared set, then exit without running.
        for(const auto& name : KernelSetRegistry::instance().names())
        {
            const auto& set = KernelSetRegistry::instance().get(name);
            std::cout << "\n " << name << ":\n";
            for(const auto& decl : set.declarations())
            {
                std::cout << " - " << decl.name() << " (tile=" << decl.algorithm.tile_m_ << "x"
                          << decl.algorithm.tile_n_ << "x" << decl.algorithm.tile_k_ << ")\n";
            }
        }
        return 0;
    }

    // =========================================================================
    // Step 2: Create registries and demonstrate MERGING
    // =========================================================================
    std::cout << "\nStep 2: Create and Merge Registries\n";
    std::cout << "------------------------------------\n";

    // Create individual registries first
    Registry compute_registry;
    Registry latency_registry;
    Registry memory_registry;

    compute_registry.set_name("compute_bound");
    latency_registry.set_name("latency_optimized");
    memory_registry.set_name("memory_bound");

    // Register kernels to individual registries using set names (no hardcoding)
    REGISTER_KERNEL_SET("compute_bound_set", compute_registry, gfx_arch);
    REGISTER_KERNEL_SET("latency_set", latency_registry, gfx_arch);
    REGISTER_KERNEL_SET("memory_bound_set", memory_registry, gfx_arch);

    std::cout << " Individual registries:\n";
    std::cout << " compute_bound: " << compute_registry.size() << " kernel(s)\n";
    std::cout << " latency_optimized: " << latency_registry.size() << " kernel(s)\n";
    std::cout << " memory_bound: " << memory_registry.size() << " kernel(s)\n";

    // MERGE compute + latency into a combined registry: registering two sets
    // into the same Registry accumulates their kernels.
    Registry combined_registry;
    combined_registry.set_name("compute_latency_combined");

    // Register both sets into combined registry
    REGISTER_KERNEL_SET("compute_bound_set", combined_registry, gfx_arch);
    REGISTER_KERNEL_SET("latency_set", combined_registry, gfx_arch);

    std::cout << "\n After merging compute + latency:\n";
    std::cout << " combined: " << combined_registry.size() << " kernel(s)\n";
    std::cout << " memory (separate): " << memory_registry.size() << " kernel(s)\n";

    // =========================================================================
    // Step 3: Create dispatchers - one merged, one separate
    // =========================================================================
    std::cout << "\nStep 3: Create Dispatchers\n";
    std::cout << "--------------------------\n";

    Dispatcher combined_dispatcher(&combined_registry); // compute + latency merged
    Dispatcher memory_dispatcher(&memory_registry);     // memory separate

    std::cout << " combined_dispatcher: compute + latency kernels (" << combined_registry.size()
              << " kernels)\n";
    std::cout << " memory_dispatcher: memory-bound kernels (" << memory_registry.size()
              << " kernels)\n";

    // =========================================================================
    // Step 4: Run with different dispatchers
    // =========================================================================
    std::cout << "\nStep 4: Run Workloads\n";
    print_separator();

    using DataType = ck_tile::fp16_t;

    // One workload per row: which dispatcher to use and the GEMM dimensions.
    struct WorkloadTest
    {
        const char* name;
        Dispatcher* dispatcher;
        int M, N, K;
    };

    std::vector<WorkloadTest> tests = {
        {"Compute-bound (combined)", &combined_dispatcher, 4096, 4096, 4096},
        {"Memory-bound (separate)", &memory_dispatcher, 1024, 1024, 1024},
        {"Latency-opt (combined)", &combined_dispatcher, 512, 512, 512},
    };

    bool all_passed = true;

    for(const auto& test : tests)
    {
        Problem problem(test.M, test.N, test.K);

        // Allocate and initialize. A and B are filled with 1.0 so every
        // element of C = A*B should equal K exactly (up to fp16 accumulation).
        GpuBuffer<DataType> a_dev(test.M * test.K);
        GpuBuffer<DataType> b_dev(test.K * test.N);
        GpuBuffer<DataType> c_dev(test.M * test.N);

        std::vector<DataType> a_host(test.M * test.K, DataType(1.0f));
        std::vector<DataType> b_host(test.K * test.N, DataType(1.0f));
        a_dev.copy_from_host(a_host.data());
        b_dev.copy_from_host(b_host.data());
        c_dev.zero();

        // Select kernel and run (nullptr = default HIP stream).
        auto selected = test.dispatcher->select_kernel(problem);
        float time_ms =
            test.dispatcher->run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
        double tflops = calculate_tflops(test.M, test.N, test.K, time_ms);

        std::cout << test.name << " (" << test.M << "x" << test.N << "x" << test.K << "):\n";
        if(selected)
            std::cout << " Selected: " << selected->get_name() << "\n";
        std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
        std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";

        // Verify ALL elements against the analytically known result.
        std::vector<DataType> c_host(test.M * test.N);
        c_dev.copy_to_host(c_host.data());
        const float expected = static_cast<float>(test.K);

        int num_errors = 0;
        float max_error = 0.0f; // NOTE(review): tracked but never printed — candidate for reporting
        for(int i = 0; i < test.M * test.N; ++i)
        {
            float actual = static_cast<float>(c_host[i]);
            float error = std::abs(actual - expected);
            max_error = std::max(max_error, error);
            // Allow 1% relative tolerance for FP16 accumulation
            if(error > 0.01f * expected + 1.0f)
                ++num_errors;
        }

        bool test_passed = (num_errors == 0);
        std::cout << " Verify: " << (test.M * test.N) << " elements, errors=" << num_errors
                  << "\n";
        std::cout << " Status: " << (test_passed ? "PASS" : "FAIL") << "\n\n";

        if(!test_passed)
            all_passed = false;
    }

    // =========================================================================
    // Summary
    // =========================================================================
    print_separator();
    std::cout << "Multi-Registry Pattern Summary:\n";
    print_separator();
    // Raw string: usage recipe echoed verbatim to the console.
    std::cout << R"(
// 1. Declare multiple kernel sets
DECL_KERNEL_SET(compute_bound_set, .add(...));
DECL_KERNEL_SET(memory_bound_set, .add(...));
DECL_KERNEL_SET(latency_set, .add(...));

// 2. Create registries and register by set NAME (no hardcoding!)
Registry combined_reg, memory_reg;
REGISTER_KERNEL_SET("compute_bound_set", combined_reg, arch); // Add compute
REGISTER_KERNEL_SET("latency_set", combined_reg, arch); // Merge latency
REGISTER_KERNEL_SET("memory_bound_set", memory_reg, arch); // Separate

// 3. Create dispatchers from merged/separate registries
Dispatcher combined_disp(&combined_reg); // Has both compute + latency
Dispatcher memory_disp(&memory_reg); // Has only memory-bound

// 4. Choose dispatcher based on workload
if (problem.is_memory_bound())
memory_disp.run(...);
else
combined_disp.run(...); // Handles both compute & latency workloads
)";
    print_separator();
    std::cout << "Overall Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n";

    return all_passed ? 0 : 1;
}
|
||||
229
dispatcher/examples/gemm/cpp/README.md
Normal file
229
dispatcher/examples/gemm/cpp/README.md
Normal file
@@ -0,0 +1,229 @@
|
||||
# GEMM C++ Examples
|
||||
|
||||
CK Tile Dispatcher C++ examples for GEMM (General Matrix Multiplication) operations.
|
||||
|
||||
> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Build and Run
|
||||
|
||||
```bash
|
||||
cd /path/to/composable_kernel/dispatcher
|
||||
mkdir -p build && cd build
|
||||
|
||||
cmake .. \
|
||||
-DCMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-DBUILD_DISPATCHER_EXAMPLES=ON
|
||||
|
||||
# Build (kernels generated automatically by CMake)
|
||||
make -j$(nproc)
|
||||
|
||||
# Run examples
|
||||
cd examples
|
||||
./gemm_01_basic
|
||||
./gemm_03_benchmark_validation
|
||||
./gemm_04_heuristics
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
| Example | Description | Complexity |
|
||||
|---------|-------------|------------|
|
||||
| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ |
|
||||
| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | ★★☆☆☆ |
|
||||
| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | ★★☆☆☆ |
|
||||
| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | ★★★☆☆ |
|
||||
| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | ★★☆☆☆ |
|
||||
| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ★★★☆☆ |
|
||||
|
||||
## Example Details
|
||||
|
||||
### 01_basic_gemm.cpp - Basic GEMM
|
||||
Demonstrates the declarative kernel API with three patterns:
|
||||
|
||||
1. **Autofill Pattern** - Minimal specification, defaults filled automatically
|
||||
2. **Autocorrect Pattern** - Invalid parameters corrected at build time
|
||||
3. **Full Specification Pattern** - Complete kernel configuration
|
||||
|
||||
```cpp
|
||||
DECL_KERNEL_SET(basic_kernels,
|
||||
// Pattern 1: Autofill - minimal specification
|
||||
.add(
|
||||
Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm(), // Defaults filled by autofill
|
||||
"gfx942"
|
||||
)
|
||||
// Pattern 2: Full specification
|
||||
.add(
|
||||
Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
|
||||
.pipeline("compv4").scheduler("intrawave"),
|
||||
"gfx942"
|
||||
)
|
||||
);
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Uses generic `REGISTER_GENERATED_KERNELS` macro
|
||||
- `print_registered_kernels()` utility for debugging
|
||||
- Demonstrates autofill messages during build
|
||||
|
||||
### 02_multi_size.cpp - Wildcard Expansion
|
||||
Demonstrates automatic generation of multiple kernel configurations:
|
||||
|
||||
```cpp
|
||||
DECL_KERNEL_SET(multi_kernels,
|
||||
.add(
|
||||
Signature().dtype("fp16").layout("rcr"),
|
||||
Algorithm().tile(ANY_INT, ANY_INT, 32) // Wildcard tile M and N (ANY_INT or -1)
|
||||
.wave(2, 2, 1)
|
||||
.warp(32, 32, 16)
|
||||
.pipeline("compv4")
|
||||
.scheduler("intrawave"),
|
||||
"gfx942"
|
||||
)
|
||||
);
|
||||
```
|
||||
|
||||
**Wildcard Values:**
|
||||
- `*`, `-1`, or `ANY_INT` expand to all valid configurations
|
||||
- Architecture filter prunes invalid combinations automatically
|
||||
- Example generates 5 valid kernels after arch filtering (from 7 expansions)
|
||||
|
||||
### 03_benchmark_validation.cpp - Benchmark + Validation
|
||||
Consolidated example combining performance benchmarking with correctness validation:
|
||||
|
||||
```bash
|
||||
# Benchmark only
|
||||
./gemm_03_benchmark_validation --warmup 10 --iterations 100
|
||||
|
||||
# With CPU validation
|
||||
./gemm_03_benchmark_validation --verify 1 --rtol 1e-3 --atol 1e-3
|
||||
|
||||
# With GPU reference validation (faster for large matrices)
|
||||
./gemm_03_benchmark_validation --verify 2
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Warmup iterations (discarded from timing)
|
||||
- Benchmark iterations with statistics (min/max/mean/median)
|
||||
- CPU reference validation using `ck_tile::reference_gemm`
|
||||
- GPU reference validation using `ck_tile::reference_gemm_gpu`
|
||||
- Configurable tolerances
|
||||
|
||||
### 04_heuristics.cpp - Heuristic Selection
|
||||
Demonstrates custom kernel selection based on problem characteristics:
|
||||
|
||||
```cpp
|
||||
// Problem size analysis
|
||||
auto heuristic = [](const Problem& p) -> std::optional<KernelKey> {
|
||||
if (p.M() * p.N() < 256 * 256) {
|
||||
return small_kernel_key; // Memory-bound heuristic
|
||||
} else {
|
||||
return large_kernel_key; // Compute-bound heuristic
|
||||
}
|
||||
};
|
||||
|
||||
dispatcher.set_heuristic(heuristic);
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Problem size analysis (small vs large matrices)
|
||||
- Compute-bound vs memory-bound selection
|
||||
- Custom heuristic function registration
|
||||
|
||||
### 05_json_export.cpp - JSON Export
|
||||
Exports registry information to JSON for external tool integration:
|
||||
|
||||
```cpp
|
||||
auto json = registry.to_json();
|
||||
std::ofstream file("kernels.json");
|
||||
file << json;
|
||||
```
|
||||
|
||||
**Use Cases:**
|
||||
- Kernel metadata serialization
|
||||
- External analysis tools
|
||||
- Configuration management
|
||||
|
||||
### 06_multi_registry.cpp - Multiple Registries
|
||||
Demonstrates using multiple registries with named kernel sets:
|
||||
|
||||
```cpp
|
||||
// Define separate kernel sets
|
||||
DECL_KERNEL_SET(compute_optimized, ...);
|
||||
DECL_KERNEL_SET(latency_optimized, ...);
|
||||
|
||||
// Register to specific registries
|
||||
Registry compute_registry, latency_registry;
|
||||
REGISTER_KERNEL_SET("compute_optimized", compute_registry, arch);
REGISTER_KERNEL_SET("latency_optimized", latency_registry, arch);

// Use appropriate registry based on workload
Dispatcher compute_dispatcher(&compute_registry);
Dispatcher latency_dispatcher(&latency_registry);
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Named kernel set registration with `REGISTER_KERNEL_SET` macro
|
||||
- Separate registries for different optimization goals
|
||||
- Dynamic kernel set selection by name
|
||||
|
||||
## Benchmark Parameters (stream_config)
|
||||
|
||||
CK Tile uses `stream_config` for benchmark control:
|
||||
|
||||
```cpp
|
||||
ck_tile::stream_config cfg{
|
||||
nullptr, // stream_id - HIP stream (nullptr = default)
|
||||
true, // time_kernel - Enable timing
|
||||
1, // log_level - Verbosity (0=quiet, 1=normal)
|
||||
5, // cold_niters - Warmup iterations
|
||||
20, // nrepeat - Benchmark iterations
|
||||
true, // is_gpu_timer - Use GPU events vs CPU chrono
|
||||
false, // flush_cache - Flush L2 cache between iterations
|
||||
1 // rotating_count - Rotating buffers for cache simulation
|
||||
};
|
||||
```
|
||||
|
||||
| Parameter | CLI Option | Default | Description |
|
||||
|-----------|------------|---------|-------------|
|
||||
| `cold_niters_` | `--warmup` | 5 | Warmup iterations |
|
||||
| `nrepeat_` | `--iterations` | 100 | Benchmark iterations |
|
||||
| `flush_cache_` | - | false | Flush L2 cache |
|
||||
| `rotating_count_` | - | 1 | Rotating buffers |
|
||||
| `is_gpu_timer_` | - | true | GPU timer vs CPU |
|
||||
|
||||
## Declarative Kernel Pattern
|
||||
|
||||
All examples use the declarative `DECL_KERNEL_SET` macro:
|
||||
|
||||
```cpp
|
||||
DECL_KERNEL_SET(my_kernels,
|
||||
.add(
|
||||
Signature() // WHAT: operation signature
|
||||
.dtype("fp16") // Data type
|
||||
.layout("rcr"), // Matrix layouts (A=row, B=col, C=row)
|
||||
Algorithm() // HOW: implementation details
|
||||
.tile(256, 256, 32) // Tile sizes (M, N, K)
|
||||
.wave(2, 2, 1) // Wave configuration
|
||||
.warp(32, 32, 16) // Warp tile sizes
|
||||
.pipeline("compv4") // Pipeline type
|
||||
.scheduler("intrawave"), // Scheduler type
|
||||
"gfx942" // WHERE: target architecture
|
||||
)
|
||||
);
|
||||
```
|
||||
|
||||
**Key Macros:**
|
||||
- `DECL_KERNEL_SET(name, ...)` - Declare a kernel set
|
||||
- `REGISTER_GENERATED_KERNELS` - Register all kernels from this example
|
||||
- `REGISTER_KERNEL_SET(name, registry, arch)` - Register a named kernel set into a registry for a target architecture
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [Python GEMM Examples](../python/README.md)
|
||||
- [Convolution Examples](../../conv/cpp/README.md)
|
||||
- [Main Dispatcher README](../../../README.md)
|
||||
331
dispatcher/examples/gemm/python/01_basic_gemm.py
Normal file
331
dispatcher/examples/gemm/python/01_basic_gemm.py
Normal file
@@ -0,0 +1,331 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 01: Basic GEMM with Multiple Kernels
|
||||
|
||||
Demonstrates:
|
||||
1. Declaring multiple kernel configurations
|
||||
2. Printing all registered kernels
|
||||
3. Running each kernel and validating output
|
||||
4. Comparing performance across kernels
|
||||
|
||||
Complexity: ★★☆☆☆
|
||||
|
||||
Usage:
|
||||
python3 01_basic_gemm.py
|
||||
python3 01_basic_gemm.py --help
|
||||
python3 01_basic_gemm.py --dtype bf16
|
||||
python3 01_basic_gemm.py --size 2048
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class KernelSpec:
    """Specification for a kernel configuration.

    A lightweight description of one GEMM kernel variant: block tile sizes
    plus pipeline/scheduler choices. Expanded into a full KernelConfig by
    create_kernel_config().
    """

    name: str  # human-readable label shown in result tables
    tile_m: int  # block tile size along M
    tile_n: int  # block tile size along N
    tile_k: int  # block tile size along K
    pipeline: str = "compv3"  # GEMM pipeline variant ("compv3" or "compv4")
    scheduler: str = "intrawave"  # wave scheduling strategy
|
||||
|
||||
|
||||
# Kernel configurations to test: 48 variants covering square/rectangular
# tiles, tile_k depths 16-128, compv3/compv4 pipelines, and both schedulers.
KERNEL_SPECS = [
    # Small tiles - compv3
    KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"),
    KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"),
    # Small tiles - compv4
    KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"),
    KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"),
    # Medium tiles - compv3
    KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"),
    KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"),
    KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"),
    # Medium tiles - compv4
    KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"),
    KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"),
    KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"),
    # Rectangular tiles - compv3
    KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"),
    KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"),
    KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"),
    KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"),
    # Rectangular tiles - compv4
    KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"),
    KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"),
    KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"),
    KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"),
    # Large tiles - compv3
    KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"),
    KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"),
    KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"),
    KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"),
    KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"),
    KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"),
    # Large tiles - compv4
    KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"),
    KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"),
    KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"),
    KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"),
    KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"),
    KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"),
    # Interwave scheduler variants
    KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"),
    KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"),
    KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"),
    # More tile_k variations - compv3
    KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"),
    KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"),
    KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"),
    # More tile_k variations - compv4
    KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"),
    KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"),
    KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"),
    # Additional rectangular
    KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"),
    KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"),
    KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"),
    KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"),
    # Additional compv4 variants
    KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"),
    KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"),
    KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"),
    KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"),
]
|
||||
|
||||
|
||||
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Expand a KernelSpec into a fully-populated KernelConfig.

    The warp tile is derived from the block tile: small blocks
    (tile_m <= 64) use 16x16 warp tiles, larger ones use 32x32.
    All other fields are fixed (fp32 accumulation, RCR layouts,
    2x2x1 waves, cshuffle epilogue).
    """
    # Warp M/N are always equal here; only the scale depends on tile_m.
    warp_mn = 16 if spec.tile_m <= 64 else 32

    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=2,
        wave_n=2,
        wave_k=1,
        warp_m=warp_mn,
        warp_n=warp_mn,
        warp_k=16,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
|
||||
|
||||
|
||||
def print_kernel_table(specs: List[KernelSpec], dtype: str):
    """Print a formatted table of kernel configurations."""
    banner = "=" * 70
    rule = " " + "-" * 68

    print("\n" + banner)
    print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
    print(banner)
    print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}")
    print(rule)

    for idx, entry in enumerate(specs, start=1):
        dims = f"{entry.tile_m}x{entry.tile_n}x{entry.tile_k}"
        print(
            f" {idx:<3} {entry.name:<18} {dims:<14} {entry.pipeline:<10} {entry.scheduler:<12}"
        )

    print(rule)
    print(f" Data type: {dtype}")
|
||||
|
||||
|
||||
def main():
    """Run every declared kernel spec on one GEMM problem and report results.

    For each spec: build a KernelConfig, set up a dispatcher (compiling the
    kernel if needed), run C = A @ B, validate against a NumPy reference, and
    print a PASS/FAIL table. Returns 0 if all tested kernels pass, else 1.
    """
    parser = argparse.ArgumentParser(
        description="Basic GEMM Example with Multiple Kernels",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 01_basic_gemm.py # Default FP16 with 4 kernels
  python3 01_basic_gemm.py --dtype bf16 # BF16 mode
  python3 01_basic_gemm.py --size 2048 # Larger problem size
  python3 01_basic_gemm.py --num-kernels 2 # Test only 2 kernels
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=512,
        help="Problem size MxNxK (default: 512)",
    )
    parser.add_argument(
        "--num-kernels",
        type=int,
        default=0,
        help="Number of kernels to test (0 = all)",
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 70)
    print("Example 01: Basic GEMM with Multiple Kernels")
    print("=" * 70)

    # Select kernels to test (0 means run the full KERNEL_SPECS list)
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS

    # =========================================================================
    # Step 1: Print all kernel configurations
    # =========================================================================
    print_kernel_table(specs, args.dtype)

    # =========================================================================
    # Step 2: Setup and test each kernel
    # =========================================================================
    print("\n" + "=" * 70)
    print(" RUNNING KERNELS")
    print("=" * 70)

    # NOTE(review): bf16 is staged through np.float16 host buffers here —
    # presumably converted on the native side; confirm against ctypes_utils.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    M, N, K = args.size, args.size, args.size

    # Each entry: (name, passed, time_ms, tflops, max_err)
    results = []

    print(f"\n Problem size: {M}x{N}x{K}\n")
    print(
        f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}"
    )
    print(" " + "-" * 78)

    for i, spec in enumerate(specs, 1):
        # Create unique (but reproducible) test data per kernel
        np.random.seed(42 + i * 1000)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        # Create config and setup dispatcher (auto_rebuild compiles on demand)
        config = create_kernel_config(spec, args.dtype, args.arch)

        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"kernel_{spec.name}",
            verbose=False,
            auto_rebuild=True,
        )

        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"

        # Setup failure: record and move on to the next spec.
        if not setup.success:
            print(
                f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((spec.name, False, 0, 0, 0))
            cleanup_gemm()
            continue

        dispatcher = setup.dispatcher

        # Check if this problem size is supported by the kernel's tiles
        if not dispatcher.is_supported(M, N, K):
            print(
                f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}"
            )
            results.append((spec.name, False, 0, 0, 0))
            cleanup_gemm()
            continue

        # Run GEMM
        result = dispatcher.run(A, B, M, N, K)

        if not result.success:
            print(
                f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((spec.name, False, 0, 0, 0))
            cleanup_gemm()
            continue

        # Validate against NumPy reference (accumulate in fp32, cast back)
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        max_err = np.max(np.abs(result.output - C_ref))

        # Check if within tolerance
        passed = max_err < 1e-2
        status = "PASS" if passed else "FAIL"

        print(
            f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}"
        )
        results.append((spec.name, passed, result.time_ms, result.tflops, max_err))

        cleanup_gemm()

    # =========================================================================
    # Step 3: Summary
    # =========================================================================
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)

    # `passed` is deliberately rebound here: now a count, not the loop's bool.
    passed = sum(1 for r in results if r[1])
    failed = len(results) - passed

    print(f"\n Results: {passed}/{len(results)} kernels passed")
    print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}")

    if results:
        valid_results = [r for r in results if r[1]]
        if valid_results:
            # Best kernel by TFLOPS (tuple index 3)
            best = max(valid_results, key=lambda x: x[3])
            print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)")

    if failed == 0:
        print("\n *** ALL KERNELS PASSED ***")
    else:
        print(f"\n *** {failed} KERNELS FAILED ***")

    print("=" * 70)

    return 0 if failed == 0 else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pass/fail as the process exit code.
    sys.exit(main())
|
||||
149
dispatcher/examples/gemm/python/02_batch_gemm.py
Normal file
149
dispatcher/examples/gemm/python/02_batch_gemm.py
Normal file
@@ -0,0 +1,149 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 02: Batch GEMM
|
||||
|
||||
Runs multiple GEMM operations with different sizes.
|
||||
|
||||
Complexity: ★★☆☆☆
|
||||
|
||||
Usage:
|
||||
python3 02_batch_gemm.py
|
||||
python3 02_batch_gemm.py --help
|
||||
python3 02_batch_gemm.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Run one fixed 128x128x32-tile GEMM kernel over a batch of sizes.

    Builds a single dispatcher, then runs square GEMMs from 256^3 up to
    --max-size, printing per-size timing/TFLOPS and an aggregate average.
    Returns 0 on completion, 1 if dispatcher setup fails.
    """
    parser = argparse.ArgumentParser(
        description="Batch GEMM Example - runs multiple sizes",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 02_batch_gemm.py # Default FP16
  python3 02_batch_gemm.py --dtype bf16 # BF16 GEMM
  python3 02_batch_gemm.py --max-size 2048 # Limit max size
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--max-size",
        type=int,
        default=4096,
        help="Maximum problem size (default: 4096)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 02: Batch GEMM")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    # Single general-purpose 128x128x32 tile reused across all batch sizes.
    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher

    # =========================================================================
    # Step 2: Run batch of different sizes
    # =========================================================================
    print("\nStep 2: Run Batch")

    # Generate sizes up to max_size
    all_sizes = [
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
    ]
    sizes = [(m, n, k) for m, n, k in all_sizes if max(m, n, k) <= args.max_size]

    # NOTE(review): bf16 staged through np.float16 host buffers — confirm
    # the native side handles the conversion.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    print(f"\n {'Size':<20} | {'Time (ms)':>12} | {'TFLOPS':>10} | {'Status':>8}")
    print(" " + "-" * 60)

    # Accumulators for the aggregate average TFLOPS across all successful runs.
    total_ops = 0
    total_time = 0

    for M, N, K in sizes:
        if not dispatcher.is_supported(M, N, K):
            print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Skipped")
            continue

        A = np.random.randn(M, K).astype(np_dtype) * 0.1
        B = np.random.randn(K, N).astype(np_dtype) * 0.1

        result = dispatcher.run(A, B, M, N, K)

        if result.success:
            # GEMM flop count: 2*M*N*K (multiply + add per output element)
            total_ops += 2 * M * N * K
            total_time += result.time_ms
            print(
                f" {M:>4}x{N:>4}x{K:<4} | {result.time_ms:>12.4f} | {result.tflops:>10.2f} | OK"
            )
        else:
            print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Error")

    print(" " + "-" * 60)

    if total_time > 0:
        # ops/1e12 -> TFLOP, time_ms/1000 -> seconds
        avg_tflops = (total_ops / 1e12) / (total_time / 1000)
        print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS")

    # Cleanup
    cleanup_gemm()

    print("\n" + "=" * 60)
    print("Batch GEMM complete!")
    print("=" * 60)

    return 0
||||
|
||||
|
||||
if __name__ == "__main__":
    # Propagate pass/fail as the process exit code.
    sys.exit(main())
|
||||
171
dispatcher/examples/gemm/python/03_benchmark.py
Normal file
171
dispatcher/examples/gemm/python/03_benchmark.py
Normal file
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 03: Benchmark
|
||||
|
||||
Performance benchmarking with compute-optimized kernel configuration.
|
||||
|
||||
Complexity: ★★★☆☆
|
||||
|
||||
Usage:
|
||||
python3 03_benchmark.py
|
||||
python3 03_benchmark.py --help
|
||||
python3 03_benchmark.py --size 4096
|
||||
python3 03_benchmark.py --dtype bf16 --iterations 20
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
    """Run the compv4/intrawave GEMM benchmark suite.

    Parses CLI options, builds a compute-optimized dispatcher, then times a
    set of square and rectangular problem sizes (or a single --size cube).

    Returns:
        0 on success, 1 if the dispatcher could not be set up.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Benchmark Example - performance testing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 03_benchmark.py                  # Default benchmark suite
  python3 03_benchmark.py --size 4096      # Single size benchmark
  python3 03_benchmark.py --dtype bf16     # BF16 benchmark
  python3 03_benchmark.py --iterations 20  # More iterations
""",
    )
    parser.add_argument(
        "--dtype",
        default="bf16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: bf16)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=0,
        help="Single problem size MxNxK (default: run all sizes)",
    )
    parser.add_argument(
        "--warmup", type=int, default=3, help="Warmup iterations (default: 3)"
    )
    parser.add_argument(
        "--iterations", type=int, default=10, help="Benchmark iterations (default: 10)"
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 03: Benchmark")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher with compute-optimized config
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    # compv4 + intrawave is the compute-optimized pipeline choice here.
    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        pipeline="compv4",
        scheduler="intrawave",
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher

    # =========================================================================
    # Step 2: Benchmark
    # =========================================================================
    print("\nStep 2: Benchmark")

    if args.size > 0:
        sizes = [(args.size, args.size, args.size)]
    else:
        sizes = [
            (512, 512, 512),
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (1024, 2048, 512),
            (2048, 1024, 2048),
        ]

    # NOTE(review): NumPy has no native bf16, so bf16 uses float16 host
    # buffers — presumably converted by the dispatcher; confirm upstream.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    print(f" Warmup: {args.warmup}, Iterations: {args.iterations}\n")

    print(f" {'Size':<20} | {'Min (ms)':>10} | {'Avg (ms)':>10} | {'TFLOPS':>10}")
    print(" " + "-" * 60)

    all_tflops = []

    for M, N, K in sizes:
        if not dispatcher.is_supported(M, N, K):
            continue

        # Scale inputs down to keep accumulations well inside fp16 range.
        A = np.random.randn(M, K).astype(np_dtype) * 0.1
        B = np.random.randn(K, N).astype(np_dtype) * 0.1

        # Warmup (untimed)
        for _ in range(args.warmup):
            dispatcher.run(A, B, M, N, K)

        # Benchmark — only successful runs contribute timing samples.
        times = []
        for _ in range(args.iterations):
            result = dispatcher.run(A, B, M, N, K)
            if result.success:
                times.append(result.time_ms)

        if times:
            min_time = min(times)
            avg_time = sum(times) / len(times)
            # 2*M*N*K flops per GEMM; convert ms -> s, flops -> Tflops.
            tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12
            all_tflops.append(tflops)
            print(
                f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}"
            )

    # Cleanup
    cleanup_gemm()

    # Summary
    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)

    if all_tflops:
        print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS")
        print(f" Peak: {max(all_tflops):.2f} TFLOPS")

    print("=" * 60)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
156
dispatcher/examples/gemm/python/04_validation.py
Normal file
156
dispatcher/examples/gemm/python/04_validation.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 04: Validation
|
||||
|
||||
Validates GPU GEMM against NumPy reference.
|
||||
|
||||
Complexity: ★★★☆☆
|
||||
|
||||
Usage:
|
||||
python3 04_validation.py
|
||||
python3 04_validation.py --help
|
||||
python3 04_validation.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
Validator,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
    """Validate GPU GEMM results against an fp32 NumPy reference.

    Runs a fixed set of test cases (identity plus random matrices of
    increasing size) and compares outputs within --rtol/--atol.

    Returns:
        0 if every executed test passed, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Validation Example - validates GPU results against NumPy",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 04_validation.py               # Default FP16 validation
  python3 04_validation.py --dtype bf16  # BF16 validation
  python3 04_validation.py --rtol 1e-2   # Relaxed tolerance
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--rtol", type=float, default=1e-3, help="Relative tolerance (default: 1e-3)"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)"
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 04: Validation")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher

    # =========================================================================
    # Step 2: Run validation tests
    # =========================================================================
    print("\nStep 2: Validation Tests")

    validator = Validator(rtol=args.rtol, atol=args.atol)
    # bf16 uses float16 host buffers (NumPy has no bf16 dtype).
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    # (label, M, N, K, input pattern)
    test_cases = [
        ("Identity", 128, 128, 128, "identity"),
        ("Small", 256, 256, 256, "random"),
        ("Medium", 512, 512, 512, "random"),
        ("Large", 1024, 1024, 1024, "random"),
        ("Non-square", 512, 1024, 256, "random"),
    ]

    passed = 0
    failed = 0

    print(f"\n {'Test':<15} | {'Size':<15} | {'Max Err':>10} | {'Status':>8}")
    print(" " + "-" * 55)

    for name, M, N, K, pattern in test_cases:
        # Unsupported sizes are skipped, not counted as failures.
        if not dispatcher.is_supported(M, N, K):
            print(f" {name:<15} | {M}x{N}x{K:<5} | {'N/A':>10} | Skipped")
            continue

        # Fixed seed makes each case reproducible across runs.
        np.random.seed(42)
        if pattern == "identity":
            A = np.eye(M, K, dtype=np_dtype)
            B = np.eye(K, N, dtype=np_dtype)
        else:
            A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
            B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            print(f" {name:<15} | {M}x{N}x{K:<5} | {'GPU Err':>10} | FAILED")
            failed += 1
            continue

        # Reference computed in fp32, then cast back to the GPU precision.
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        is_valid, max_err, _ = validator.check(result.output, C_ref)

        if is_valid:
            print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | PASSED")
            passed += 1
        else:
            print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED")
            failed += 1

    # Cleanup
    cleanup_gemm()

    # Summary
    print("\n" + "=" * 60)
    total = passed + failed
    print(f"Results: {passed}/{total} passed")
    print(f"Settings: dtype={args.dtype}, rtol={args.rtol}, atol={args.atol}")
    print("=" * 60)

    return 0 if failed == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
|
||||
166
dispatcher/examples/gemm/python/05_numpy_integration.py
Normal file
166
dispatcher/examples/gemm/python/05_numpy_integration.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 05: NumPy Integration
|
||||
|
||||
Shows how to create a GPU-accelerated matmul wrapper.
|
||||
|
||||
Complexity: ★★☆☆☆
|
||||
|
||||
Usage:
|
||||
python3 05_numpy_integration.py
|
||||
python3 05_numpy_integration.py --help
|
||||
python3 05_numpy_integration.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
Dispatcher,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
class GPUMatmul:
    """Callable that routes matrix multiplication to a GPU dispatcher.

    Falls back to NumPy's CPU matmul whenever the dispatcher does not
    support the problem size or the GPU run does not succeed.
    """

    def __init__(self, dispatcher: Dispatcher):
        # Dispatcher owning the compiled GPU GEMM kernels.
        self.dispatcher = dispatcher

    def __call__(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """Return C = A @ B, computed on GPU when possible, else on CPU."""
        rows, inner = A.shape
        inner_b, cols = B.shape

        if inner != inner_b:
            raise ValueError(f"Dimension mismatch: {A.shape} @ {B.shape}")

        if self.dispatcher.is_supported(rows, cols, inner):
            outcome = self.dispatcher.run(A, B, rows, cols, inner)
            if outcome.success:
                return outcome.output

        # CPU fallback path.
        return np.matmul(A, B)
|
||||
|
||||
|
||||
def main() -> int:
    """Demonstrate wrapping the dispatcher as a drop-in GPU matmul.

    Builds a dispatcher, wraps it in GPUMatmul, then runs a simple
    multiplication demo and a two-layer FFN-style demo.

    Returns:
        0 on success, 1 if the dispatcher could not be set up.
    """
    parser = argparse.ArgumentParser(
        description="NumPy Integration Example - GPU-accelerated matmul wrapper",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 05_numpy_integration.py              # Default FP16
  python3 05_numpy_integration.py --dtype bf16 # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 05: NumPy Integration")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="numpy", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    dispatcher = setup.dispatcher
    # bf16 uses float16 host buffers (NumPy has no bf16 dtype).
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    # =========================================================================
    # Step 2: Create GPU matmul wrapper
    # =========================================================================
    print("\nStep 2: Create GPUMatmul")

    gpu_matmul = GPUMatmul(dispatcher=dispatcher)
    print(" gpu_matmul ready")

    # =========================================================================
    # Step 3: Demo - Simple multiplication using gpu_matmul
    # =========================================================================
    print("\nStep 3: Demo - Simple Multiplication")

    A = np.random.randn(1024, 512).astype(np_dtype) * 0.1
    B = np.random.randn(512, 256).astype(np_dtype) * 0.1

    # Use the gpu_matmul wrapper (GPU with CPU fallback).
    C = gpu_matmul(A, B)
    print(f" gpu_matmul result: {C.shape}, sum={C.sum():.4f}")

    # Same product via the raw dispatcher to show timing/TFLOPS fields.
    M, K = A.shape
    _, N = B.shape
    result = dispatcher.run(A, B, M, N, K)

    print(f" A: {A.shape}, B: {B.shape} -> C: {result.output.shape}")
    print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS")

    # =========================================================================
    # Step 4: Demo - FFN block
    # =========================================================================
    print("\nStep 4: Demo - FFN Block")

    # Transformer-style feed-forward shapes: hidden -> ffn -> hidden.
    batch, hidden, ffn = 128, 768, 3072
    X = np.random.randn(batch, hidden).astype(np_dtype) * 0.02
    W1 = np.random.randn(hidden, ffn).astype(np_dtype) * 0.02
    W2 = np.random.randn(ffn, hidden).astype(np_dtype) * 0.02

    result1 = dispatcher.run(X, W1, batch, ffn, hidden)
    H = result1.output
    result2 = dispatcher.run(H, W2, batch, hidden, ffn)

    print(f" X: {X.shape} -> H: {H.shape} -> Y: {result2.output.shape}")
    print(f" Total: {result1.time_ms + result2.time_ms:.4f} ms")

    # Cleanup
    cleanup_gemm()

    # Summary
    print("\n" + "=" * 60)
    print("NumPy Integration Pattern:")
    print("=" * 60)
    print(" 1. setup_gemm_dispatcher(config)")
    print(" 2. GPUMatmul(dispatcher)")
    print(" 3. C = gpu_matmul(A, B)")
    print("=" * 60)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
169
dispatcher/examples/gemm/python/06_json_export.py
Normal file
169
dispatcher/examples/gemm/python/06_json_export.py
Normal file
@@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 06: JSON Export
|
||||
|
||||
Exports registry configuration to JSON.
|
||||
|
||||
Complexity: ★★☆☆☆
|
||||
|
||||
Usage:
|
||||
python3 06_json_export.py
|
||||
python3 06_json_export.py --help
|
||||
python3 06_json_export.py --output my_kernels.json
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
    """Export kernel registry metadata to a JSON file.

    Sets up a dispatcher, describes three kernel configurations, merges in
    the C++ registry's own JSON (best effort), and writes the result to
    --output, printing a short preview.

    Returns:
        0 on success, 1 if the dispatcher could not be set up.
    """
    parser = argparse.ArgumentParser(
        description="JSON Export Example - exports registry to JSON",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 06_json_export.py                  # Default output to kernels.json
  python3 06_json_export.py --output my.json # Custom output file
""",
    )
    parser.add_argument(
        "--output",
        "-o",
        default="kernels.json",
        help="Output JSON file (default: kernels.json)",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    print("=" * 60)
    print("Example 06: JSON Export")
    print("=" * 60)

    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")

    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )

    setup = setup_gemm_dispatcher(config, registry_name="export_demo", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    # =========================================================================
    # Step 2: Define additional configs for export
    # =========================================================================
    print("\nStep 2: Define Additional Configs")

    # The registered config plus two extra tile variants (large and small).
    configs = [
        config,
        KernelConfig(
            dtype_a=args.dtype,
            dtype_b=args.dtype,
            dtype_c=args.dtype,
            tile_m=256,
            tile_n=256,
            tile_k=64,
            gfx_arch=args.arch,
        ),
        KernelConfig(
            dtype_a=args.dtype,
            dtype_b=args.dtype,
            dtype_c=args.dtype,
            tile_m=64,
            tile_n=64,
            tile_k=32,
            gfx_arch=args.arch,
        ),
    ]

    for cfg in configs:
        print(f" - {cfg.tile_str}")

    # =========================================================================
    # Step 3: Export to JSON
    # =========================================================================
    print("\nStep 3: Export to JSON")

    export_data = {
        "registry": setup.registry.name,
        "kernel_count": len(configs),
        "kernels": [],
    }

    for cfg in configs:
        kernel_info = {
            "tile": cfg.tile_str,
            "dtypes": {"A": cfg.dtype_a, "B": cfg.dtype_b, "C": cfg.dtype_c},
            "layout": cfg.layout,
            "pipeline": cfg.pipeline,
            "target": cfg.gfx_arch,
        }
        export_data["kernels"].append(kernel_info)

    # Include C++ library info (best effort: malformed JSON is ignored).
    if setup.lib:
        cpp_json = setup.lib.export_registry_json()
        try:
            export_data["cpp_registry"] = json.loads(cpp_json)
        except json.JSONDecodeError:
            pass

    json_str = json.dumps(export_data, indent=2)

    with open(args.output, "w") as f:
        f.write(json_str)
    print(f" Saved to: {args.output}")

    # Preview (truncated to the first 500 characters)
    print("\nStep 4: Preview")
    print("-" * 60)
    print(json_str[:500] + ("..." if len(json_str) > 500 else ""))
    print("-" * 60)

    # Cleanup
    cleanup_gemm()

    print("\n" + "=" * 60)
    print("JSON Export complete!")
    print("=" * 60)

    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
513
dispatcher/examples/gemm/python/07_stress_test.py
Normal file
513
dispatcher/examples/gemm/python/07_stress_test.py
Normal file
@@ -0,0 +1,513 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 07: Stress Test - Multiple Kernels with Validation
|
||||
|
||||
Consolidated stress test that:
|
||||
1. Declares multiple kernel configurations (various tiles, pipelines, layouts)
|
||||
2. Prints all registered kernels with details
|
||||
3. Validates each kernel against NumPy reference
|
||||
4. Optional benchmarking mode
|
||||
|
||||
This tests:
|
||||
- Multiple tile sizes (64x64, 128x128, 256x256)
|
||||
- Multiple pipelines (compv3, compv4)
|
||||
- Multiple data types (fp16, bf16)
|
||||
- Different schedulers (intrawave, interwave)
|
||||
|
||||
Complexity: ★★★★☆
|
||||
|
||||
Usage:
|
||||
python3 07_stress_test.py
|
||||
python3 07_stress_test.py --help
|
||||
python3 07_stress_test.py --num-kernels 10
|
||||
python3 07_stress_test.py --benchmark
|
||||
python3 07_stress_test.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
Validator,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class KernelSpec:
    """Declarative description of one GEMM kernel variant under test."""

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    layout: str = "rcr"

    def to_config(self, dtype: str, arch: str) -> KernelConfig:
        """Materialize this spec as a KernelConfig for the given dtype/arch."""
        # Single-letter layout codes ("rcr") expand to per-operand layouts.
        axis_name = {"r": "row", "c": "col"}

        return KernelConfig(
            dtype_a=dtype,
            dtype_b=dtype,
            dtype_c=dtype,
            dtype_acc="fp32",
            layout_a=axis_name[self.layout[0]],
            layout_b=axis_name[self.layout[1]],
            layout_c=axis_name[self.layout[2]],
            tile_m=self.tile_m,
            tile_n=self.tile_n,
            tile_k=self.tile_k,
            wave_m=self.wave_m,
            wave_n=self.wave_n,
            wave_k=self.wave_k,
            # Clamp warp tiles so they never exceed tile/wave for small tiles.
            warp_m=min(self.warp_m, self.tile_m // self.wave_m),
            warp_n=min(self.warp_n, self.tile_n // self.wave_n),
            warp_k=self.warp_k,
            pipeline=self.pipeline,
            scheduler=self.scheduler,
            epilogue="cshuffle",
            gfx_arch=arch,
        )
|
||||
|
||||
|
||||
# Define stress test kernel configurations
|
||||
# The stress-test sweep: tile shapes from 64x64 up to 256x128, both compv3
# and compv4 pipelines, rectangular tiles, and one interwave-scheduler entry.
KERNEL_SPECS = [
    # Small tiles - compv3 / compv4
    KernelSpec("small_compv3", 64, 64, 32, wave_m=2, wave_n=2,
               warp_m=16, warp_n=16, warp_k=32, pipeline="compv3"),
    KernelSpec("small_compv4", 64, 64, 32, wave_m=2, wave_n=2,
               warp_m=16, warp_n=16, warp_k=32, pipeline="compv4"),
    # Medium tiles
    KernelSpec("medium_compv3", 128, 128, 32, wave_m=2, wave_n=2,
               warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    KernelSpec("medium_compv4", 128, 128, 32, wave_m=2, wave_n=2,
               warp_m=32, warp_n=32, warp_k=16, pipeline="compv4"),
    KernelSpec("medium_k64", 128, 128, 64, wave_m=2, wave_n=2,
               warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    # Rectangular tiles
    KernelSpec("rect_64x128", 64, 128, 32, wave_m=2, wave_n=2,
               warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    KernelSpec("rect_128x64", 128, 64, 32, wave_m=2, wave_n=2,
               warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    # Different schedulers
    KernelSpec("interwave", 128, 128, 32, wave_m=2, wave_n=2,
               warp_m=32, warp_n=32, warp_k=16, pipeline="compv3",
               scheduler="interwave"),
    # Large tiles
    KernelSpec("large_compv3", 256, 128, 32, wave_m=2, wave_n=2,
               warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    KernelSpec("large_compv4", 256, 128, 64, wave_m=2, wave_n=2,
               warp_m=32, warp_n=32, warp_k=16, pipeline="compv4"),
]
|
||||
|
||||
|
||||
def print_kernel_summary(specs: List[KernelSpec], dtype: str):
    """Render a fixed-width table describing every kernel spec."""
    banner = "=" * 80
    divider = " " + "-" * 78

    print("\n" + banner)
    print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
    print(banner)
    print(
        f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Wave':<10} {'Warp':<12} {'Pipeline':<10} {'Sched':<10}"
    )
    print(divider)

    for idx, item in enumerate(specs, 1):
        # Each shape column is rendered as MxNxK.
        shape_tile = f"{item.tile_m}x{item.tile_n}x{item.tile_k}"
        shape_wave = f"{item.wave_m}x{item.wave_n}x{item.wave_k}"
        shape_warp = f"{item.warp_m}x{item.warp_n}x{item.warp_k}"
        print(
            f" {idx:<3} {item.name:<18} {shape_tile:<12} {shape_wave:<10} {shape_warp:<12} {item.pipeline:<10} {item.scheduler:<10}"
        )

    print(divider)
    print(f" Data type: {dtype}\n")
|
||||
|
||||
|
||||
def validate_kernel(
    spec: KernelSpec,
    dtype: str,
    arch: str,
    size: int,
    validator: Validator,
    kernel_index: int = 0,
    verbose: bool = False,
) -> Tuple[bool, float, str]:
    """Build, run, and numerically validate one kernel configuration.

    Runs a size x size x size GEMM on the GPU and compares the result
    against an fp32 NumPy reference cast back to the kernel precision.

    Returns:
        Tuple of (passed, max_error, message).
    """
    host_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32

    setup = setup_gemm_dispatcher(
        config=spec.to_config(dtype, arch),
        registry_name=f"stress_{spec.name}",
        verbose=False,
        auto_rebuild=True,
    )
    if not setup.success:
        return False, 0.0, f"Setup failed: {setup.error}"

    gemm = setup.dispatcher
    M = N = K = size

    if not gemm.is_supported(M, N, K):
        cleanup_gemm()
        return False, 0.0, f"Size {M}x{N}x{K} not supported"

    # Seed differs per kernel so each configuration sees unique matrices.
    np.random.seed(42 + kernel_index * 1000)
    lhs = (np.random.randn(M, K) * 0.1).astype(host_dtype)
    rhs = (np.random.randn(K, N) * 0.1).astype(host_dtype)

    outcome = gemm.run(lhs, rhs, M, N, K)
    if not outcome.success:
        cleanup_gemm()
        return False, 0.0, "GPU execution failed"

    # fp32 reference, cast back to the kernel's output precision.
    reference = np.matmul(lhs.astype(np.float32), rhs.astype(np.float32)).astype(
        host_dtype
    )
    ok, max_err, _ = validator.check(outcome.output, reference)

    cleanup_gemm()

    return ok, max_err, f"{outcome.time_ms:.2f}ms, {outcome.tflops:.1f} TFLOPS"
|
||||
|
||||
|
||||
def benchmark_kernel(
    spec: KernelSpec,
    dtype: str,
    arch: str,
    size: int,
    warmup: int = 3,
    iterations: int = 10,
) -> Tuple[bool, float, float]:
    """Time one kernel configuration on a square size x size x size GEMM.

    Returns:
        Tuple of (success, avg_time_ms, tflops).
    """
    host_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32

    setup = setup_gemm_dispatcher(
        config=spec.to_config(dtype, arch),
        registry_name=f"bench_{spec.name}",
        verbose=False,
        auto_rebuild=True,
    )
    if not setup.success:
        return False, 0.0, 0.0

    gemm = setup.dispatcher
    M = N = K = size

    if not gemm.is_supported(M, N, K):
        cleanup_gemm()
        return False, 0.0, 0.0

    lhs = (np.random.randn(M, K) * 0.1).astype(host_dtype)
    rhs = (np.random.randn(K, N) * 0.1).astype(host_dtype)

    # Untimed warmup runs.
    for _ in range(warmup):
        gemm.run(lhs, rhs, M, N, K)

    # Timed runs; only successful iterations contribute samples.
    samples = []
    for _ in range(iterations):
        outcome = gemm.run(lhs, rhs, M, N, K)
        if outcome.success:
            samples.append(outcome.time_ms)

    cleanup_gemm()

    if not samples:
        return False, 0.0, 0.0

    mean_ms = sum(samples) / len(samples)
    # 2*M*N*K flops per GEMM; ms -> s, flops -> Tflops.
    return True, mean_ms, (2.0 * M * N * K / (mean_ms * 1e-3)) / 1e12
|
||||
|
||||
|
||||
def main():
    """CLI entry point: validate every selected kernel spec, optionally benchmark, summarize."""
    parser = argparse.ArgumentParser(
        description="GEMM Stress Test - Multiple kernels with validation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python3 07_stress_test.py                  # Test all kernels
    python3 07_stress_test.py --num-kernels 5  # Test first 5 kernels
    python3 07_stress_test.py --benchmark      # Include benchmarks
    python3 07_stress_test.py --dtype bf16     # Test BF16
    python3 07_stress_test.py --size 2048      # Use 2048x2048 matrices
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--num-kernels",
        type=int,
        default=0,
        help="Number of kernels to test (0 = all)",
    )
    parser.add_argument(
        "--size", type=int, default=512, help="Problem size MxNxK (default: 512)"
    )
    parser.add_argument(
        "--benchmark", action="store_true", help="Include benchmark timing"
    )
    parser.add_argument(
        "--rtol", type=float, default=1e-2, help="Relative tolerance (default: 1e-2)"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)"
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    banner = "=" * 80
    print(banner)
    print("Example 07: GEMM Stress Test - Multiple Kernels")
    print(banner)

    # Restrict the kernel list only when --num-kernels is positive.
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS
    print_kernel_summary(specs, args.dtype)

    print("\n" + banner)
    print(" VALIDATION RESULTS")
    print(banner)

    validator = Validator(rtol=args.rtol, atol=args.atol)

    # Table header differs depending on whether timing columns are shown.
    if args.benchmark:
        print(
            f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Time':>10} {'TFLOPS':>8} {'Status':<8}"
        )
    else:
        print(
            f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Info':<25} {'Status':<8}"
        )
    print(" " + "-" * 78)

    passed = failed = skipped = 0

    for idx, spec in enumerate(specs, 1):
        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        try:
            is_valid, max_err, info = validate_kernel(
                spec, args.dtype, args.arch, args.size, validator, kernel_index=idx
            )

            if is_valid:
                status = "PASS"
                passed += 1
            else:
                status = "FAIL"
                failed += 1

            if args.benchmark:
                ok, avg_time, tflops = benchmark_kernel(
                    spec, args.dtype, args.arch, args.size
                )
                if ok:
                    print(
                        f" {idx:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {avg_time:>9.2f}ms {tflops:>7.1f} {status:<8}"
                    )
                else:
                    print(
                        f" {idx:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {'N/A':>10} {'N/A':>8} {status:<8}"
                    )
            else:
                print(
                    f" {idx:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {info:<25} {status:<8}"
                )
        except Exception as e:
            # A kernel that cannot even be set up is skipped, not failed.
            skipped += 1
            print(
                f" {idx:<3} {spec.name:<18} {tile:<12} {'N/A':>10} {str(e)[:25]:<25} {'SKIP':<8}"
            )

    # Summary
    print("\n" + banner)
    print(" SUMMARY")
    print(banner)
    total = passed + failed + skipped
    print(f"\n Results: {passed}/{total} passed, {failed} failed, {skipped} skipped")
    print(f" Settings: dtype={args.dtype}, size={args.size}x{args.size}x{args.size}")
    print(f" Tolerance: rtol={args.rtol}, atol={args.atol}")
    print(f" Architecture: {args.arch}")

    if failed == 0 and skipped == 0:
        print("\n *** ALL KERNELS PASSED ***")
    elif failed > 0:
        print(f"\n *** {failed} KERNELS FAILED ***")

    print(banner)
    return 0 if failed == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
718
dispatcher/examples/gemm/python/08_heuristics.py
Normal file
718
dispatcher/examples/gemm/python/08_heuristics.py
Normal file
@@ -0,0 +1,718 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 08: Custom Heuristics
|
||||
|
||||
Demonstrates custom kernel selection heuristics based on problem characteristics.
|
||||
|
||||
This example shows how to:
|
||||
1. Define multiple kernel configurations for different workloads
|
||||
2. Implement custom heuristics to select the best kernel
|
||||
3. Test heuristic selection across different problem sizes
|
||||
|
||||
Heuristic strategies:
|
||||
- Size-based: Small tiles for small problems, large tiles for large problems
|
||||
- Compute-bound: Maximize compute utilization for large matrices
|
||||
- Memory-bound: Optimize memory access for bandwidth-limited cases
|
||||
- Latency-focused: Minimize kernel launch overhead for small problems
|
||||
|
||||
Complexity: ★★★★☆
|
||||
|
||||
Usage:
|
||||
python3 08_heuristics.py
|
||||
python3 08_heuristics.py --help
|
||||
python3 08_heuristics.py --strategy compute
|
||||
python3 08_heuristics.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kernel Specifications
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@dataclass
class KernelSpec:
    """Kernel specification with metadata for heuristic selection.

    The tile/pipeline/scheduler fields feed create_kernel_config(); the
    metadata fields (category plus the [min, max] problem-size window,
    measured in M*N output elements) are consumed only by the heuristic
    functions in this module.
    """

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    # Metadata for heuristics
    category: str = "balanced"  # small, balanced, large, compute, memory
    min_problem_size: int = 0
    # FIX: annotated `float`, not `int` — the unbounded default is float("inf"),
    # which is not an int, and comparisons against it are float comparisons.
    max_problem_size: float = float("inf")
|
||||
# Kernel pool for heuristic selection (20+ kernels), organized by category.
# Built from compact rows to keep the table scannable:
#   (name, tile_m, tile_n, tile_k, pipeline, category, min_size, max_size)
# All entries use the "intrawave" scheduler.
_POOL_ROWS = [
    # SMALL TILES - Low latency, good for small problems
    ("small_64x64_k32", 64, 64, 32, "compv3", "small", 0, 256 * 256),
    ("small_64x64_k64", 64, 64, 64, "compv3", "small", 0, 256 * 256),
    ("small_64x64_v4", 64, 64, 32, "compv4", "small", 0, 256 * 256),
    # MEDIUM TILES - Balanced performance
    ("medium_128x128_k32", 128, 128, 32, "compv3", "balanced", 128 * 128, 2048 * 2048),
    ("medium_128x128_k64", 128, 128, 64, "compv3", "balanced", 256 * 256, float("inf")),
    ("medium_128x128_k128", 128, 128, 128, "compv3", "balanced", 256 * 256, float("inf")),
    ("medium_128x128_v4_k32", 128, 128, 32, "compv4", "balanced", 256 * 256, float("inf")),
    ("medium_128x128_v4_k64", 128, 128, 64, "compv4", "balanced", 256 * 256, float("inf")),
    # Rectangular medium tiles
    ("rect_64x128_k32", 64, 128, 32, "compv3", "balanced", 128 * 128, float("inf")),
    ("rect_128x64_k32", 128, 64, 32, "compv3", "balanced", 128 * 128, float("inf")),
    ("rect_64x128_k64", 64, 128, 64, "compv3", "balanced", 256 * 256, float("inf")),
    ("rect_128x64_k64", 128, 64, 64, "compv3", "balanced", 256 * 256, float("inf")),
    # LARGE TILES - High throughput for large problems
    ("large_256x128_k32", 256, 128, 32, "compv3", "large", 512 * 512, float("inf")),
    ("large_256x128_k64", 256, 128, 64, "compv3", "large", 512 * 512, float("inf")),
    ("large_128x256_k32", 128, 256, 32, "compv3", "large", 512 * 512, float("inf")),
    ("large_128x256_k64", 128, 256, 64, "compv3", "large", 512 * 512, float("inf")),
    ("large_256x256_k32", 256, 256, 32, "compv3", "large", 1024 * 1024, float("inf")),
    ("large_256x256_k64", 256, 256, 64, "compv3", "large", 1024 * 1024, float("inf")),
    # COMPUTE-OPTIMIZED - compv4 pipeline for compute-bound workloads
    ("compute_128x128_v4_k32", 128, 128, 32, "compv4", "compute", 256 * 256, float("inf")),
    ("compute_128x128_v4_k64", 128, 128, 64, "compv4", "compute", 256 * 256, float("inf")),
    ("compute_256x128_v4", 256, 128, 64, "compv4", "compute", 512 * 512, float("inf")),
    ("compute_256x256_v4", 256, 256, 64, "compv4", "compute", 1024 * 1024, float("inf")),
    # MEMORY-OPTIMIZED - Good cache utilization for memory-bound workloads
    ("memory_128x128_k16", 128, 128, 16, "compv3", "memory", 256 * 256, float("inf")),
    ("memory_64x128_k16", 64, 128, 16, "compv3", "memory", 128 * 128, float("inf")),
]

KERNEL_POOL = [
    KernelSpec(
        name,
        tile_m,
        tile_n,
        tile_k,
        pipeline,
        "intrawave",
        category=category,
        min_problem_size=lo,
        max_problem_size=hi,
    )
    for (name, tile_m, tile_n, tile_k, pipeline, category, lo, hi) in _POOL_ROWS
]
||||
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Translate a KernelSpec into a full KernelConfig for the dispatcher.

    Warp width follows the tile size: 16 for tiles of 64 or less,
    32 otherwise. All other parameters are fixed for this example
    (row x col -> row layout, fp32 accumulation, cshuffle epilogue).
    """
    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=2,
        wave_n=2,
        wave_k=1,
        warp_m=32 if spec.tile_m > 64 else 16,
        warp_n=32 if spec.tile_n > 64 else 16,
        warp_k=16,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
|
||||
|
||||
# =============================================================================
|
||||
# Heuristic Strategies
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class HeuristicStrategy(Enum):
    """Available kernel-selection strategies.

    Values intentionally match the --strategy CLI choices.
    """

    SIZE_BASED = "size"
    COMPUTE_BOUND = "compute"
    MEMORY_BOUND = "memory"
    LATENCY_FOCUSED = "latency"
|
||||
|
||||
def size_based_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select kernel based on problem size.
    - Small problems: Use small tiles for low latency
    - Medium problems: Use balanced tiles
    - Large problems: Use large tiles for high throughput

    Also considers K dimension for tile_k selection (prefers a tile_k
    that divides K evenly). Does not modify `kernels`.
    """
    total_elements = M * N

    # Filter by each kernel's declared [min, max] problem-size window.
    candidates = [
        k for k in kernels if k.min_problem_size <= total_elements <= k.max_problem_size
    ]
    if not candidates:
        candidates = kernels  # Fall back to all kernels

    # Determine target category based on problem size.
    if total_elements < 256 * 256:
        target_category = "small"
    elif total_elements < 1024 * 1024:
        target_category = "balanced"
    else:
        target_category = "large"

    # Narrow to the target category if any candidate matches it.
    category_candidates = [k for k in candidates if k.category == target_category]
    if category_candidates:
        candidates = category_candidates

    def tile_k_score(k):
        # 0 means tile_k divides K perfectly; otherwise the remainder
        # (lower is better).
        return K % k.tile_k

    # BUG FIX: rank with sorted() instead of list.sort(). In the fallback
    # path above, `candidates` aliases the caller's `kernels` list, and an
    # in-place sort would silently reorder the caller's kernel pool.
    ranked = sorted(candidates, key=lambda k: (tile_k_score(k), -k.tile_m * k.tile_n))
    return ranked[0]
|
||||
def compute_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select kernel optimized for compute-bound workloads.
    Prefers compv4 pipeline and larger tiles.
    Selects based on problem size to maximize compute utilization.
    """
    total_elements = M * N

    # Candidate pool: "compute" category first, then any compv4 kernel,
    # then the whole pool as a last resort.
    pool = [k for k in kernels if k.category == "compute"]
    if not pool:
        pool = [k for k in kernels if k.pipeline == "compv4"]
    if not pool:
        pool = kernels

    # Keep only kernels whose minimum problem size is satisfied, if any are.
    sized = [k for k in pool if k.min_problem_size <= total_elements]
    if sized:
        pool = sized

    if total_elements >= 1024 * 1024:
        # Large problems: maximize tile volume for throughput.
        return max(pool, key=lambda k: k.tile_m * k.tile_n * k.tile_k)
    # Smaller problems: pick the tile closest to 128x128.
    return min(pool, key=lambda k: abs(k.tile_m - 128) + abs(k.tile_n - 128))
|
||||
def memory_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select kernel optimized for memory-bound workloads.
    Prefers smaller tile_k for better memory access patterns.
    """

    def _area(k):
        return k.tile_m * k.tile_n

    # "memory" category kernels take priority when present.
    memory_pool = [k for k in kernels if k.category == "memory"]
    if memory_pool:
        # Small problems get the smallest tile; larger ones the biggest.
        chooser = min if M * N < 512 * 512 else max
        return chooser(memory_pool, key=_area)

    # Otherwise prefer a balanced kernel with the smallest tile_k.
    balanced_pool = [k for k in kernels if k.category == "balanced"]
    if balanced_pool:
        return min(balanced_pool, key=lambda k: k.tile_k)

    # Last resort: smallest tile_k, tie-broken by closeness to 128x128.
    return min(
        kernels, key=lambda k: (k.tile_k, abs(k.tile_m - 128) + abs(k.tile_n - 128))
    )
|
||||
def latency_focused_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select kernel optimized for low latency.
    Prefers smaller tiles and compv4 for faster execution.
    """

    def _footprint(k):
        return k.tile_m * k.tile_n

    small = [k for k in kernels if k.category == "small"]
    if small:
        # Among small kernels the first compv4 wins; otherwise keep the
        # first small kernel in pool order.
        for candidate in small:
            if candidate.pipeline == "compv4":
                return candidate
        return small[0]

    # No "small" kernels: smallest compv4 tile if one exists...
    v4_pool = [k for k in kernels if k.pipeline == "compv4"]
    if v4_pool:
        return min(v4_pool, key=_footprint)

    # ...otherwise simply the smallest tile overall.
    return min(kernels, key=_footprint)
|
||||
# Maps each strategy enum member to its selection function; main() resolves
# the --strategy CLI flag through this table.
HEURISTICS = {
    HeuristicStrategy.SIZE_BASED: size_based_heuristic,
    HeuristicStrategy.COMPUTE_BOUND: compute_bound_heuristic,
    HeuristicStrategy.MEMORY_BOUND: memory_bound_heuristic,
    HeuristicStrategy.LATENCY_FOCUSED: latency_focused_heuristic,
}
|
||||
|
||||
# =============================================================================
|
||||
# Main
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def print_kernel_pool(kernels: List[KernelSpec]):
|
||||
"""Print available kernels"""
|
||||
print("\n" + "=" * 75)
|
||||
print(" KERNEL POOL")
|
||||
print("=" * 75)
|
||||
print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Category':<12}")
|
||||
print(" " + "-" * 73)
|
||||
|
||||
for i, k in enumerate(kernels, 1):
|
||||
tile = f"{k.tile_m}x{k.tile_n}x{k.tile_k}"
|
||||
print(f" {i:<3} {k.name:<22} {tile:<14} {k.pipeline:<10} {k.category:<12}")
|
||||
|
||||
print(" " + "-" * 73)
|
||||
|
||||
|
||||
def main():
    """Run the heuristic-selection demo across a sweep of problem sizes."""
    parser = argparse.ArgumentParser(
        description="Custom Heuristics Example - intelligent kernel selection",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python3 08_heuristics.py                    # Default size-based heuristic
    python3 08_heuristics.py --strategy compute # Compute-bound heuristic
    python3 08_heuristics.py --strategy memory  # Memory-bound heuristic
    python3 08_heuristics.py --strategy latency # Latency-focused heuristic
    python3 08_heuristics.py --dtype bf16       # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--strategy",
        default="size",
        choices=["size", "compute", "memory", "latency"],
        help="Heuristic strategy (default: size)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    rule = "=" * 75
    print(rule)
    print("Example 08: Custom Heuristics")
    print(rule)

    # CLI string -> enum member (enum values equal the CLI choices) -> callable.
    strategy = HeuristicStrategy(args.strategy)
    heuristic_fn = HEURISTICS[strategy]

    print(f"\n Strategy: {strategy.value}")
    print(f" Data type: {args.dtype}")

    print_kernel_pool(KERNEL_POOL)

    # -------------------------------------------------------------------------
    # Exercise the heuristic across different problem sizes
    # -------------------------------------------------------------------------
    print("\n" + rule)
    print(" HEURISTIC SELECTION TEST")
    print(rule)

    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    test_sizes = [
        (128, 128, 64),  # Small
        (256, 256, 128),  # Small-medium
        (512, 512, 256),  # Medium
        (1024, 1024, 512),  # Medium-large
        (2048, 2048, 1024),  # Large
    ]

    print(
        f"\n {'Size':<20} {'Selected Kernel':<25} {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    )
    print(" " + "-" * 78)

    results = []

    for M, N, K in test_sizes:
        # Heuristic picks the kernel for this shape.
        selected_spec = heuristic_fn(M, N, K, KERNEL_POOL)
        size_str = f"{M}x{N}x{K}"

        def bail(tag):
            # Record a non-running entry, print it, and release the dispatcher.
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {tag:<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()

        setup = setup_gemm_dispatcher(
            config=create_kernel_config(selected_spec, args.dtype, args.arch),
            registry_name=f"heuristic_{selected_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )
        if not setup.success:
            bail("FAIL")
            continue

        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            bail("SKIP")
            continue

        # Deterministic inputs so validation is reproducible.
        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)

        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            bail("FAIL")
            continue

        # Validate against a float32 NumPy reference.
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        max_err = np.max(np.abs(result.output - C_ref))
        ok = max_err < 1e-2

        status = "PASS" if ok else "FAIL"
        print(
            f" {size_str:<20} {selected_spec.name:<25} {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
        )
        results.append(
            (size_str, selected_spec.name, ok, result.time_ms, result.tflops)
        )
        cleanup_gemm()

    # -------------------------------------------------------------------------
    # Summary
    # -------------------------------------------------------------------------
    print("\n" + rule)
    print(" SUMMARY")
    print(rule)

    passed = sum(1 for r in results if r[2])
    failed = len(results) - passed

    print(f"\n Strategy: {strategy.value}")
    print(f" Results: {passed}/{len(results)} tests passed")

    # How often each kernel was chosen by the heuristic.
    kernel_usage = {}
    for r in results:
        kernel_usage[r[1]] = kernel_usage.get(r[1], 0) + 1

    print("\n Kernel Selection Distribution:")
    for kernel, count in sorted(kernel_usage.items(), key=lambda x: -x[1]):
        print(f" {kernel}: {count} times")

    if results:
        timed = [r for r in results if r[2]]
        if timed:
            avg_tflops = sum(r[4] for r in timed) / len(timed)
            print(f"\n Average TFLOPS: {avg_tflops:.2f}")

    if failed == 0:
        print("\n *** ALL TESTS PASSED ***")
    else:
        print(f"\n *** {failed} TESTS FAILED ***")

    print(rule)
    return 0 if failed == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
220
dispatcher/examples/gemm/python/09_multi_registry.py
Normal file
220
dispatcher/examples/gemm/python/09_multi_registry.py
Normal file
@@ -0,0 +1,220 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 09: Multiple Registries
|
||||
|
||||
Demonstrates multiple registries for different optimization targets.
|
||||
|
||||
Complexity: ★★★★★
|
||||
|
||||
Usage:
|
||||
python3 09_multi_registry.py
|
||||
python3 09_multi_registry.py --help
|
||||
python3 09_multi_registry.py --dtype bf16
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
Registry,
|
||||
Dispatcher,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def main():
    """Demonstrate routing GEMM problems across optimization-specific registries."""
    parser = argparse.ArgumentParser(
        description="Multiple Registries Example - optimization-specific registries",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python3 09_multi_registry.py              # Default FP16
    python3 09_multi_registry.py --dtype bf16 # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()

    reset_for_example()

    rule = "=" * 60
    print(rule)
    print("Example 09: Multiple Registries")
    print(rule)

    # -------------------------------------------------------------------------
    # Step 1: base dispatcher (also loads the shared library handle).
    # -------------------------------------------------------------------------
    print("\nStep 1: Setup Base Dispatcher")

    base_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )
    setup = setup_gemm_dispatcher(base_config, registry_name="base", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1

    lib = setup.lib
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32

    # -------------------------------------------------------------------------
    # Step 2: one config per optimization target.
    # -------------------------------------------------------------------------
    print("\nStep 2: Define Optimization Targets")

    def make_config(tile_m, tile_n, tile_k, wave_m, wave_n, pipeline):
        # All targets share dtype/arch; only tiling, waves and pipeline vary.
        return KernelConfig(
            dtype_a=args.dtype,
            dtype_b=args.dtype,
            dtype_c=args.dtype,
            tile_m=tile_m,
            tile_n=tile_n,
            tile_k=tile_k,
            wave_m=wave_m,
            wave_n=wave_n,
            pipeline=pipeline,
            gfx_arch=args.arch,
        )

    compute_config = make_config(256, 256, 64, 4, 4, "compv4")
    memory_config = make_config(128, 128, 32, 2, 2, "compv4")
    latency_config = make_config(64, 64, 32, 1, 1, "compv3")

    print(f" Compute: {compute_config.tile_str} (large matrices)")
    print(f" Memory: {memory_config.tile_str} (medium matrices)")
    print(f" Latency: {latency_config.tile_str} (small matrices)")

    # -------------------------------------------------------------------------
    # Step 3: a registry per target, each holding one kernel.
    # -------------------------------------------------------------------------
    print("\nStep 3: Create Registries")

    compute_registry = Registry(name="compute", lib=lib)
    compute_registry.register_kernel(compute_config)
    memory_registry = Registry(name="memory", lib=lib)
    memory_registry.register_kernel(memory_config)
    latency_registry = Registry(name="latency", lib=lib)
    latency_registry.register_kernel(latency_config)

    # -------------------------------------------------------------------------
    # Step 4: a dispatcher per registry.
    # -------------------------------------------------------------------------
    print("\nStep 4: Create Dispatchers")

    compute_dispatcher = Dispatcher(registry=compute_registry, lib=lib)
    memory_dispatcher = Dispatcher(registry=memory_registry, lib=lib)
    latency_dispatcher = Dispatcher(registry=latency_registry, lib=lib)

    print(f" {compute_dispatcher}")
    print(f" {memory_dispatcher}")
    print(f" {latency_dispatcher}")

    # -------------------------------------------------------------------------
    # Step 5: route each problem to a dispatcher by output size (M*N).
    # -------------------------------------------------------------------------
    print("\nStep 5: Smart Dispatcher Selection")

    def select_dispatcher(M: int, N: int, K: int) -> Dispatcher:
        elements = M * N
        if elements >= 4096 * 4096:
            return compute_dispatcher
        if elements >= 1024 * 1024:
            return memory_dispatcher
        return latency_dispatcher

    test_sizes = [
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
    ]

    print(f"\n {'Size':<20} {'Registry':>10} {'Time (ms)':>12} {'TFLOPS':>10}")
    print(" " + "-" * 55)

    for M, N, K in test_sizes:
        dispatcher = select_dispatcher(M, N, K)
        if not dispatcher.is_supported(M, N, K):
            continue

        A = np.random.randn(M, K).astype(np_dtype) * 0.1
        B = np.random.randn(K, N).astype(np_dtype) * 0.1

        result = dispatcher.run(A, B, M, N, K)
        if result.success:
            print(
                f" {M}x{N}x{K:<10} {dispatcher.registry.name:>10} "
                f"{result.time_ms:>12.4f} {result.tflops:>10.2f}"
            )

    # Cleanup
    cleanup_gemm()

    # Summary
    print("\n" + rule)
    print("Multi-Registry Pattern:")
    print(rule)
    for line in (
        " 1. Define KernelConfig for each optimization target",
        " 2. Create Registry for each target",
        " 3. Register configs to appropriate registries",
        " 4. Create Dispatcher for each registry",
        " 5. Select dispatcher based on problem characteristics",
        " 6. Run GEMM with selected dispatcher",
    ):
        print(line)
    print(rule)

    return 0


if __name__ == "__main__":
    sys.exit(main())
260
dispatcher/examples/gemm/python/10_advanced_benchmark.py
Normal file
260
dispatcher/examples/gemm/python/10_advanced_benchmark.py
Normal file
@@ -0,0 +1,260 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 10: Advanced Benchmarking with Full Control
|
||||
|
||||
This example demonstrates all available benchmark parameters:
|
||||
- warmup: Number of warmup iterations (default: 5)
|
||||
- repeat: Number of benchmark iterations (default: 20)
|
||||
- flush_cache: Flush GPU cache between iterations (default: False)
|
||||
- timer: Timer type - "gpu" (default) or "cpu"
|
||||
- init: Initialization method - "random", "linear", "constant"
|
||||
|
||||
Usage:
|
||||
python3 10_advanced_benchmark.py
|
||||
python3 10_advanced_benchmark.py --warmup 10 --repeat 100
|
||||
python3 10_advanced_benchmark.py --init linear
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add paths for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ctypes_utils import (
|
||||
KernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Advanced GEMM benchmarking with full parameter control"
|
||||
)
|
||||
|
||||
# Problem size
|
||||
parser.add_argument("-m", type=int, default=2048, help="M dimension")
|
||||
parser.add_argument("-n", type=int, default=2048, help="N dimension")
|
||||
parser.add_argument("-k", type=int, default=2048, help="K dimension")
|
||||
|
||||
# Benchmark parameters
|
||||
parser.add_argument(
|
||||
"--warmup", type=int, default=5, help="Number of warmup iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat", type=int, default=20, help="Number of benchmark iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flush-cache", action="store_true", help="Flush GPU cache between iterations"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timer", choices=["gpu", "cpu"], default="gpu", help="Timer type (gpu or cpu)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init",
|
||||
choices=["random", "linear", "constant"],
|
||||
default="random",
|
||||
help="Initialization method",
|
||||
)
|
||||
|
||||
# Kernel configuration
|
||||
parser.add_argument("--dtype", default="fp16", help="Data type")
|
||||
parser.add_argument("--pipeline", default="compv4", help="Pipeline type")
|
||||
parser.add_argument("--arch", default="gfx942", help="GPU architecture")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def initialize_matrix(shape, method, dtype):
|
||||
"""Initialize matrix with specified method"""
|
||||
if method == "random":
|
||||
return np.random.randn(*shape).astype(dtype) * 0.5
|
||||
elif method == "linear":
|
||||
total = np.prod(shape)
|
||||
return np.arange(total).reshape(shape).astype(dtype) / total
|
||||
elif method == "constant":
|
||||
return np.ones(shape, dtype=dtype)
|
||||
else:
|
||||
return np.random.randn(*shape).astype(dtype)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
reset_for_example()
|
||||
|
||||
print("=" * 70)
|
||||
print("Example 10: Advanced GEMM Benchmarking")
|
||||
print("=" * 70)
|
||||
|
||||
# Show benchmark configuration
|
||||
print("\nBenchmark Configuration:")
|
||||
print(f" Problem Size: {args.m} x {args.n} x {args.k}")
|
||||
print(f" Warmup: {args.warmup} iterations")
|
||||
print(f" Repeat: {args.repeat} iterations")
|
||||
print(f" Flush Cache: {args.flush_cache}")
|
||||
print(f" Timer: {args.timer}")
|
||||
print(f" Init Method: {args.init}")
|
||||
print(f" Data Type: {args.dtype}")
|
||||
print(f" Pipeline: {args.pipeline}")
|
||||
print(f" Architecture: {args.arch}")
|
||||
print()
|
||||
|
||||
# Map dtype
|
||||
np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
|
||||
|
||||
# Initialize matrices
|
||||
print("Step 1: Initialize matrices...")
|
||||
A = initialize_matrix((args.m, args.k), args.init, np_dtype)
|
||||
B = initialize_matrix((args.k, args.n), args.init, np_dtype)
|
||||
print(f" A: {A.shape} ({args.init})")
|
||||
print(f" B: {B.shape} ({args.init})")
|
||||
|
||||
# Create kernel config (does not include M/N/K - those are problem size)
|
||||
print("\nStep 2: Create kernel configuration...")
|
||||
kernel_config = KernelConfig(
|
||||
dtype_a=args.dtype,
|
||||
dtype_b=args.dtype,
|
||||
dtype_c=args.dtype,
|
||||
dtype_acc="fp32",
|
||||
layout_a="row",
|
||||
layout_b="col", # B is column-major for optimal performance
|
||||
layout_c="row",
|
||||
tile_m=128,
|
||||
tile_n=128,
|
||||
tile_k=32,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
warp_m=32,
|
||||
warp_n=32,
|
||||
warp_k=16,
|
||||
pipeline=args.pipeline,
|
||||
scheduler="intrawave",
|
||||
epilogue="cshuffle",
|
||||
gfx_arch=args.arch,
|
||||
)
|
||||
print(f" Config: {args.dtype}, tile=128x128x32, {args.pipeline}")
|
||||
|
||||
# Setup dispatcher
|
||||
print("\nStep 3: Setup dispatcher...")
|
||||
setup = setup_gemm_dispatcher(
|
||||
config=kernel_config,
|
||||
registry_name="benchmark_gemm",
|
||||
verbose=False,
|
||||
auto_rebuild=True,
|
||||
)
|
||||
|
||||
if not setup.success:
|
||||
print(f" ERROR: {setup.error}")
|
||||
return 1
|
||||
|
||||
dispatcher = setup.dispatcher
|
||||
print(f" Library: {setup.lib.path if setup.lib else 'N/A'}")
|
||||
print(f" Kernel: {setup.lib.get_kernel_name() if setup.lib else 'N/A'}")
|
||||
|
||||
# Run benchmark with multiple iterations
|
||||
print("\nStep 4: Run benchmark...")
|
||||
print(f" Running {args.warmup} warmup + {args.repeat} benchmark iterations...")
|
||||
|
||||
# Warmup
|
||||
for _ in range(args.warmup):
|
||||
_ = dispatcher.run(A, B, args.m, args.n, args.k)
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(args.repeat):
|
||||
result = dispatcher.run(A, B, args.m, args.n, args.k)
|
||||
if result.success:
|
||||
times.append(result.time_ms)
|
||||
|
||||
if times:
|
||||
avg_time = sum(times) / len(times)
|
||||
min_time = min(times)
|
||||
max_time = max(times)
|
||||
|
||||
# Calculate TFLOPS
|
||||
flops = 2 * args.m * args.n * args.k
|
||||
avg_tflops = (flops / 1e12) / (avg_time / 1000) if avg_time > 0 else 0
|
||||
max_tflops = (flops / 1e12) / (min_time / 1000) if min_time > 0 else 0
|
||||
|
||||
# Calculate bandwidth (C has same dtype as A and B)
|
||||
C_bytes = args.m * args.n * np.dtype(np_dtype).itemsize
|
||||
bandwidth_gb = (
|
||||
(A.nbytes + B.nbytes + C_bytes) / 1e9 / (avg_time / 1000)
|
||||
if avg_time > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
print(f"\n *** BENCHMARK RESULTS ({args.repeat} iterations) ***")
|
||||
print(f" Average Time: {avg_time:.4f} ms")
|
||||
print(f" Min Time: {min_time:.4f} ms")
|
||||
print(f" Max Time: {max_time:.4f} ms")
|
||||
print(f" Avg TFLOPS: {avg_tflops:.2f}")
|
||||
print(f" Peak TFLOPS: {max_tflops:.2f}")
|
||||
print(f" Bandwidth: {bandwidth_gb:.2f} GB/s")
|
||||
else:
|
||||
print(" FAILED: No successful runs")
|
||||
return 1
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 70)
|
||||
print("BENCHMARK PARAMETERS REFERENCE")
|
||||
print("=" * 70)
|
||||
print("""
|
||||
Available parameters for GEMM benchmarking:
|
||||
|
||||
--warmup N Number of warmup iterations (discard results)
|
||||
Higher = more stable results, longer run time
|
||||
Default: 5
|
||||
|
||||
--repeat N Number of benchmark iterations
|
||||
Higher = more accurate average, longer run time
|
||||
Default: 20
|
||||
|
||||
--flush-cache Flush GPU L2 cache between iterations
|
||||
Use for memory-bound benchmarks
|
||||
Default: off
|
||||
|
||||
--timer {gpu,cpu} Timer type
|
||||
gpu = HIP events (more accurate for GPU)
|
||||
cpu = std::chrono (includes kernel launch overhead)
|
||||
Default: gpu
|
||||
|
||||
--init METHOD Matrix initialization
|
||||
random = uniform random [-0.5, 0.5]
|
||||
linear = sequential values
|
||||
constant = all ones
|
||||
Default: random
|
||||
|
||||
Note: For C++ examples, these parameters are passed to stream_config:
|
||||
|
||||
ck_tile::stream_config cfg{
|
||||
nullptr, // stream_id
|
||||
true, // time_kernel
|
||||
1, // log_level
|
||||
5, // cold_niters (warmup)
|
||||
20, // nrepeat
|
||||
true, // is_gpu_timer
|
||||
false, // flush_cache
|
||||
1 // rotating_count
|
||||
};
|
||||
""")
|
||||
|
||||
# Cleanup
|
||||
cleanup_gemm()
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
310
dispatcher/examples/gemm/python/11_json_import.py
Normal file
310
dispatcher/examples/gemm/python/11_json_import.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 11: JSON-based Kernel Configuration Import
|
||||
|
||||
Demonstrates loading kernel configurations from JSON files, similar to tile_engine.
|
||||
This enables easy customization of kernel sets without modifying code.
|
||||
|
||||
Key Features:
|
||||
- Load tile configs from JSON (compatible with tile_engine format)
|
||||
- Generate kernel sets from configuration
|
||||
- Use arch_filter validation on loaded configs
|
||||
- Export to C++ DECL_KERNEL_SET format
|
||||
|
||||
Complexity: ★★★☆☆
|
||||
|
||||
Usage:
|
||||
python3 11_json_import.py
|
||||
python3 11_json_import.py --config my_kernels.json
|
||||
python3 11_json_import.py --export-cpp
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
# Add codegen to path for kernel_config_loader
|
||||
script_dir = Path(__file__).parent.resolve()
|
||||
sys.path.insert(0, str(script_dir.parent.parent.parent / "codegen"))
|
||||
sys.path.insert(0, str(script_dir.parent.parent.parent / "python"))
|
||||
|
||||
from kernel_config_loader import ( # noqa: E402
|
||||
load_kernel_configs,
|
||||
KernelConfig,
|
||||
generate_cpp_kernel_set_declaration,
|
||||
)
|
||||
|
||||
from ctypes_utils import ( # noqa: E402
|
||||
KernelConfig as DispatcherKernelConfig,
|
||||
setup_gemm_dispatcher,
|
||||
cleanup_gemm,
|
||||
reset_for_example,
|
||||
validate_kernel_config,
|
||||
)
|
||||
|
||||
# Sample JSON configuration (embedded for demonstration)
|
||||
SAMPLE_JSON_CONFIG = {
|
||||
"_comment": "Sample kernel configuration for GEMM",
|
||||
"kernel_set_name": "inference_kernels",
|
||||
"datatype": {"a": "fp16", "b": "fp16", "c": "fp16", "acc": "fp32"},
|
||||
"layout": "rcr",
|
||||
"tile_config": {
|
||||
"tile_m": {"values": [128, 256]},
|
||||
"tile_n": {"values": [128, 256]},
|
||||
"tile_k": {"values": [32]},
|
||||
"warp_m": {"values": [2]},
|
||||
"warp_n": {"values": [2]},
|
||||
"warp_k": {"values": [1]},
|
||||
"warp_tile_m": {"values": [32]},
|
||||
"warp_tile_n": {"values": [32]},
|
||||
"warp_tile_k": {"values": [16]},
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {"values": ["compv4"]},
|
||||
"scheduler": {"values": ["intrawave"]},
|
||||
"epilogue": {"values": ["cshuffle"]},
|
||||
"pad_m": {"values": [False]},
|
||||
"pad_n": {"values": [False]},
|
||||
"pad_k": {"values": [False]},
|
||||
},
|
||||
"gpu_targets": ["gfx942"],
|
||||
}
|
||||
|
||||
|
||||
def print_section(title: str):
|
||||
"""Print a section header"""
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f" {title}")
|
||||
print(f"{'=' * 70}\n")
|
||||
|
||||
|
||||
def convert_to_dispatcher_config(
|
||||
config: KernelConfig, arch: str = "gfx942"
|
||||
) -> DispatcherKernelConfig:
|
||||
"""Convert kernel_config_loader.KernelConfig to dispatcher KernelConfig"""
|
||||
return DispatcherKernelConfig(
|
||||
dtype_a=config.dtype_a,
|
||||
dtype_b=config.dtype_b,
|
||||
dtype_c=config.dtype_c,
|
||||
dtype_acc=config.dtype_acc,
|
||||
tile_m=config.tile.tile_m,
|
||||
tile_n=config.tile.tile_n,
|
||||
tile_k=config.tile.tile_k,
|
||||
wave_m=config.tile.warp_m,
|
||||
wave_n=config.tile.warp_n,
|
||||
wave_k=config.tile.warp_k,
|
||||
warp_m=config.tile.warp_tile_m,
|
||||
warp_n=config.tile.warp_tile_n,
|
||||
warp_k=config.tile.warp_tile_k,
|
||||
pipeline=config.trait.pipeline,
|
||||
scheduler=config.trait.scheduler,
|
||||
epilogue=config.trait.epilogue,
|
||||
pad_m=config.trait.pad_m,
|
||||
pad_n=config.trait.pad_n,
|
||||
pad_k=config.trait.pad_k,
|
||||
gfx_arch=arch,
|
||||
variant=config.variant,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="JSON Kernel Configuration Import Example",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python3 11_json_import.py # Use embedded sample config
|
||||
python3 11_json_import.py --config my.json # Load from file
|
||||
python3 11_json_import.py --export-cpp # Generate C++ declarations
|
||||
python3 11_json_import.py --validate # Validate configs against arch
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
help="Path to JSON configuration file (uses embedded sample if not provided)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--export-cpp",
|
||||
action="store_true",
|
||||
help="Export kernel set as C++ DECL_KERNEL_SET",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validate",
|
||||
action="store_true",
|
||||
help="Validate all configurations against arch filter",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch",
|
||||
default="gfx942",
|
||||
help="Target GPU architecture (default: gfx942)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
reset_for_example()
|
||||
|
||||
print_section("Example 11: JSON Kernel Configuration Import")
|
||||
|
||||
# =========================================================================
|
||||
# Step 1: Load configuration from JSON
|
||||
# =========================================================================
|
||||
print("Step 1: Load Kernel Configuration from JSON")
|
||||
print("-" * 50)
|
||||
|
||||
if args.config:
|
||||
config_path = Path(args.config)
|
||||
if not config_path.exists():
|
||||
print(f" ERROR: Config file not found: {config_path}")
|
||||
return 1
|
||||
print(f" Loading from: {config_path}")
|
||||
config_set = load_kernel_configs(config_path)
|
||||
else:
|
||||
# Use embedded sample config
|
||||
print(" Using embedded sample configuration")
|
||||
# Write to temp file and load
|
||||
temp_path = Path("/tmp/sample_gemm_config.json")
|
||||
with open(temp_path, "w") as f:
|
||||
json.dump(SAMPLE_JSON_CONFIG, f, indent=2)
|
||||
config_set = load_kernel_configs(temp_path)
|
||||
|
||||
print(f"\n Kernel Set Name: {config_set.name}")
|
||||
print(
|
||||
f" Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}"
|
||||
)
|
||||
print(f" Layout: {config_set.layout}")
|
||||
print(f" GPU Targets: {config_set.gpu_targets}")
|
||||
print(f" Total Configurations: {config_set.config_count()}")
|
||||
|
||||
# =========================================================================
|
||||
# Step 2: Display configuration details
|
||||
# =========================================================================
|
||||
print("\nStep 2: Configuration Details")
|
||||
print("-" * 50)
|
||||
|
||||
print("\n Tile Configurations:")
|
||||
print(f" tile_m: {config_set.tile_m_values}")
|
||||
print(f" tile_n: {config_set.tile_n_values}")
|
||||
print(f" tile_k: {config_set.tile_k_values}")
|
||||
print(
|
||||
f" warp (wave): {config_set.warp_m_values}x{config_set.warp_n_values}x{config_set.warp_k_values}"
|
||||
)
|
||||
print(
|
||||
f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}"
|
||||
)
|
||||
|
||||
print("\n Trait Configurations:")
|
||||
print(f" pipeline: {config_set.pipeline_values}")
|
||||
print(f" scheduler: {config_set.scheduler_values}")
|
||||
print(f" epilogue: {config_set.epilogue_values}")
|
||||
print(
|
||||
f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}"
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Step 3: Generate and display kernel names
|
||||
# =========================================================================
|
||||
print("\nStep 3: Generated Kernel Names")
|
||||
print("-" * 50)
|
||||
|
||||
configs = list(config_set.generate_configs())
|
||||
for i, config in enumerate(configs[:5]):
|
||||
print(f" {i + 1}. {config.kernel_name()}")
|
||||
if len(configs) > 5:
|
||||
print(f" ... and {len(configs) - 5} more configurations")
|
||||
|
||||
# =========================================================================
|
||||
# Step 4: Validate against arch filter (optional)
|
||||
# =========================================================================
|
||||
if args.validate:
|
||||
print("\nStep 4: Architecture Validation")
|
||||
print("-" * 50)
|
||||
|
||||
valid_count = 0
|
||||
invalid_count = 0
|
||||
|
||||
for config in configs:
|
||||
disp_config = convert_to_dispatcher_config(config, args.arch)
|
||||
result = validate_kernel_config(disp_config)
|
||||
|
||||
if result.is_valid:
|
||||
valid_count += 1
|
||||
else:
|
||||
invalid_count += 1
|
||||
if invalid_count <= 3: # Show first 3 invalid
|
||||
print(f"\n ✗ Invalid: {config.kernel_name()}")
|
||||
for error in result.errors:
|
||||
print(f" Error: {error}")
|
||||
|
||||
print("\n Validation Summary:")
|
||||
print(f" ✓ Valid: {valid_count}")
|
||||
print(f" ✗ Invalid: {invalid_count}")
|
||||
print(f" Total: {len(configs)}")
|
||||
|
||||
# =========================================================================
|
||||
# Step 5: Export to C++ (optional)
|
||||
# =========================================================================
|
||||
if args.export_cpp:
|
||||
print("\nStep 5: C++ Export")
|
||||
print("-" * 50)
|
||||
print("\n // Generated DECL_KERNEL_SET from JSON config:")
|
||||
print(" // " + "=" * 56)
|
||||
cpp_code = generate_cpp_kernel_set_declaration(config_set)
|
||||
for line in cpp_code.split("\n"):
|
||||
print(f" {line}")
|
||||
|
||||
# =========================================================================
|
||||
# Step 6: Use first config with dispatcher (demo)
|
||||
# =========================================================================
|
||||
print("\nStep 6: Dispatcher Integration Demo")
|
||||
print("-" * 50)
|
||||
|
||||
if configs:
|
||||
first_config = configs[0]
|
||||
disp_config = convert_to_dispatcher_config(first_config, args.arch)
|
||||
|
||||
print(
|
||||
f"\n Using first config: {first_config.tile.tile_m}x{first_config.tile.tile_n}x{first_config.tile.tile_k}"
|
||||
)
|
||||
|
||||
setup = setup_gemm_dispatcher(
|
||||
disp_config, registry_name="json_import", verbose=False
|
||||
)
|
||||
if setup.success:
|
||||
print(" ✓ Dispatcher setup successful")
|
||||
print(
|
||||
f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}"
|
||||
)
|
||||
else:
|
||||
print(f" ⚠ Dispatcher setup: {setup.error}")
|
||||
print(" (This is expected if kernels aren't generated)")
|
||||
|
||||
# =========================================================================
|
||||
# Summary
|
||||
# =========================================================================
|
||||
print_section("Summary")
|
||||
print(" JSON configuration allows easy kernel set customization:")
|
||||
print(" - Define tile sizes and ranges")
|
||||
print(" - Specify trait combinations (pipeline, scheduler, etc.)")
|
||||
print(" - Target multiple GPU architectures")
|
||||
print(" - Export to C++ DECL_KERNEL_SET for static compilation")
|
||||
print()
|
||||
print(" JSON Format (tile_engine compatible):")
|
||||
print(' {"tile_config": {"tile_m": {"values": [128, 256]}, ...},')
|
||||
print(' "trait_config": {"pipeline": {"values": ["compv4"]}, ...}}')
|
||||
print()
|
||||
print(" Usage:")
|
||||
print(" config_set = load_kernel_configs('my_kernels.json')")
|
||||
print(" for config in config_set.generate_configs():")
|
||||
print(" # Use config for codegen or dispatcher setup")
|
||||
|
||||
cleanup_gemm()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
299
dispatcher/examples/gemm/python/README.md
Normal file
299
dispatcher/examples/gemm/python/README.md
Normal file
@@ -0,0 +1,299 @@
|
||||
# GEMM Python Examples
|
||||
|
||||
CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations.
|
||||
|
||||
> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Build Library
|
||||
|
||||
```bash
|
||||
cd /path/to/composable_kernel/dispatcher
|
||||
mkdir -p build && cd build
|
||||
|
||||
cmake .. \
|
||||
-DCMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-DBUILD_DISPATCHER_EXAMPLES=ON
|
||||
|
||||
# Build Python library (kernels generated automatically)
|
||||
make dispatcher_gemm_lib -j$(nproc)
|
||||
```
|
||||
|
||||
### Run Examples
|
||||
|
||||
```bash
|
||||
cd /path/to/composable_kernel/dispatcher
|
||||
|
||||
python3 examples/gemm/python/01_basic_gemm.py
|
||||
python3 examples/gemm/python/04_validation.py
|
||||
python3 examples/gemm/python/07_stress_test.py
|
||||
python3 examples/gemm/python/08_heuristics.py
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
| Example | Description |
|
||||
|---------|-------------|
|
||||
| [01_basic_gemm.py](01_basic_gemm.py) | Basic GEMM with multi-kernel support |
|
||||
| [02_batch_gemm.py](02_batch_gemm.py) | Batched GEMM operations |
|
||||
| [03_benchmark.py](03_benchmark.py) | Performance benchmarking |
|
||||
| [04_validation.py](04_validation.py) | CPU reference validation |
|
||||
| [05_numpy_integration.py](05_numpy_integration.py) | NumPy array integration |
|
||||
| [06_json_export.py](06_json_export.py) | Registry JSON export |
|
||||
| [07_stress_test.py](07_stress_test.py) | Multi-kernel stress testing |
|
||||
| [08_heuristics.py](08_heuristics.py) | Heuristic-based kernel selection |
|
||||
| [09_multi_registry.py](09_multi_registry.py) | Multiple registries |
|
||||
| [10_advanced_benchmark.py](10_advanced_benchmark.py) | Advanced benchmark with full control |
|
||||
| [11_json_import.py](11_json_import.py) | Import kernels from JSON |
|
||||
|
||||
## Example Details
|
||||
|
||||
### 01_basic_gemm.py - Basic GEMM
|
||||
Demonstrates the Python API with multi-kernel support:
|
||||
|
||||
```python
|
||||
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table
|
||||
|
||||
# Define multiple kernel configurations
|
||||
kernels = [
|
||||
KernelConfig(
|
||||
tile_m=128, tile_n=128, tile_k=32,
|
||||
wave_m=2, wave_n=2, wave_k=1,
|
||||
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
|
||||
pipeline="compv3", scheduler="intrawave"
|
||||
),
|
||||
KernelConfig(
|
||||
tile_m=256, tile_n=256, tile_k=32,
|
||||
wave_m=2, wave_n=2, wave_k=1,
|
||||
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
|
||||
pipeline="compv4", scheduler="intrawave"
|
||||
),
|
||||
]
|
||||
|
||||
# Display configurations
|
||||
print_kernel_config_table(kernels)
|
||||
|
||||
# Set up dispatcher with all kernels
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
|
||||
|
||||
# Run GEMM
|
||||
elapsed_ms = run_gemm(lib, M, N, K, ...)
|
||||
```
|
||||
|
||||
### 02_batch_gemm.py - Batch GEMM
|
||||
Batched matrix multiplication:
|
||||
- Multiple independent GEMM operations
|
||||
- Batch dimension handling
|
||||
|
||||
### 03_benchmark.py - Benchmarking
|
||||
Performance measurement:
|
||||
- GPU timing
|
||||
- TFLOPS calculation
|
||||
- Multiple iterations
|
||||
|
||||
### 04_validation.py - Validation
|
||||
Correctness verification:
|
||||
- NumPy reference implementation
|
||||
- Tolerance-based validation
|
||||
- Error reporting
|
||||
|
||||
### 05_numpy_integration.py - NumPy Integration
|
||||
Seamless NumPy integration:
|
||||
- NumPy arrays to GPU buffers
|
||||
- Results back to NumPy
|
||||
- Automatic type conversion
|
||||
|
||||
### 06_json_export.py - JSON Export
|
||||
Registry serialization for tool integration:
|
||||
- Export kernel configurations
|
||||
- Machine-readable format
|
||||
|
||||
### 07_stress_test.py - Stress Testing
|
||||
Comprehensive multi-kernel stress testing:
|
||||
|
||||
```python
|
||||
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table
|
||||
|
||||
# Define 48 unique kernel configurations
|
||||
kernels = [
|
||||
KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...),
|
||||
KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...),
|
||||
KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...),
|
||||
# ... many more configurations
|
||||
]
|
||||
|
||||
# Test each kernel
|
||||
for i, kernel in enumerate(kernels):
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher([kernel])
|
||||
result = run_and_validate(lib, M, N, K, seed=42 + i) # Different seed per kernel
|
||||
print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}")
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- 48 unique kernel configurations
|
||||
- Various tile sizes, pipelines, and schedulers
|
||||
- Per-kernel validation with unique random seeds
|
||||
- Performance reporting
|
||||
|
||||
### 08_heuristics.py - Heuristic Selection
|
||||
Custom kernel selection based on problem characteristics:
|
||||
|
||||
```python
|
||||
# Define kernel pools for different strategies
|
||||
SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...]
|
||||
LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...]
|
||||
COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...]
|
||||
MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...]
|
||||
|
||||
# Size-based heuristic
|
||||
def size_based_heuristic(M, N, K):
|
||||
if M * N < 512 * 512:
|
||||
return SMALL_KERNELS
|
||||
else:
|
||||
return LARGE_KERNELS
|
||||
|
||||
# Strategy-based selection
|
||||
def compute_strategy():
|
||||
return COMPUTE_KERNELS # Optimized for compute-bound problems
|
||||
|
||||
def memory_strategy():
|
||||
return MEMORY_KERNELS # Optimized for memory-bound problems
|
||||
|
||||
# Test different strategies
|
||||
for strategy in [size_based_heuristic, compute_strategy, memory_strategy]:
|
||||
kernels = strategy(M, N, K)
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
|
||||
elapsed_ms = run_gemm(lib, M, N, K, ...)
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- 24 kernel configurations across 6 categories
|
||||
- Size-based heuristic (small vs large)
|
||||
- Optimization strategies (compute, memory, latency)
|
||||
- Performance comparison across strategies
|
||||
|
||||
### 09_multi_registry.py - Multiple Registries
|
||||
Separate registries for different workloads:
|
||||
- Compute-optimized registry
|
||||
- Latency-optimized registry
|
||||
- Dynamic registry selection
|
||||
|
||||
### 10_advanced_benchmark.py - Advanced Benchmark
|
||||
Full control over benchmark parameters:
|
||||
- Warmup iterations
|
||||
- Benchmark iterations
|
||||
- Statistical analysis
|
||||
|
||||
### 11_json_import.py - JSON Import
|
||||
Import kernel configurations from JSON:
|
||||
- External configuration files
|
||||
- Dynamic kernel loading
|
||||
|
||||
## Utility Module: ctypes_utils.py
|
||||
|
||||
```python
|
||||
from ctypes_utils import (
|
||||
KernelConfig, # Single kernel configuration
|
||||
setup_gemm_dispatcher, # Set up dispatcher with kernels
|
||||
print_kernel_config_table, # Display kernel configurations
|
||||
Dispatcher, # High-level dispatcher
|
||||
Registry, # Kernel registry
|
||||
Validator, # Validation utilities
|
||||
)
|
||||
```
|
||||
|
||||
### KernelConfig
|
||||
|
||||
```python
|
||||
config = KernelConfig(
|
||||
# Tile sizes
|
||||
tile_m=256, tile_n=256, tile_k=32,
|
||||
# Wave configuration
|
||||
wave_m=2, wave_n=2, wave_k=1,
|
||||
# Warp tile sizes
|
||||
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
|
||||
# Pipeline and scheduler
|
||||
pipeline="compv4", # "compv3" or "compv4"
|
||||
scheduler="intrawave", # "intrawave" or "interwave"
|
||||
# Optional
|
||||
epilogue="default",
|
||||
padding=True,
|
||||
double_buffer=True,
|
||||
)
|
||||
```
|
||||
|
||||
### setup_gemm_dispatcher
|
||||
|
||||
```python
|
||||
# Single kernel
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher(config)
|
||||
|
||||
# Multiple kernels
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...])
|
||||
|
||||
# With auto-rebuild
|
||||
lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True)
|
||||
```
|
||||
|
||||
### print_kernel_config_table
|
||||
|
||||
```python
|
||||
kernels = [config1, config2, config3]
|
||||
print_kernel_config_table(kernels)
|
||||
# Output:
|
||||
# +----+-------+-------+-------+--------+-----------+
|
||||
# | # | Tile | Wave | Warp | Pipe | Scheduler |
|
||||
# +----+-------+-------+-------+--------+-----------+
|
||||
# | 1 | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave |
|
||||
# | 2 | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave |
|
||||
# | 3 | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave |
|
||||
# +----+-------+-------+-------+--------+-----------+
|
||||
```
|
||||
|
||||
### GPU Memory Management
|
||||
|
||||
```python
|
||||
import ctypes
|
||||
import numpy as np
|
||||
|
||||
# Load HIP library
|
||||
hip = ctypes.CDLL("libamdhip64.so")
|
||||
|
||||
# Allocate GPU memory
|
||||
gpu_ptr = ctypes.c_void_p()
|
||||
hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes)
|
||||
|
||||
# Copy to GPU (1 = hipMemcpyHostToDevice)
|
||||
hip.hipMemcpy(gpu_ptr, host_array.ctypes.data, size, 1)
|
||||
|
||||
# Copy back (2 = hipMemcpyDeviceToHost)
|
||||
hip.hipMemcpy(host_array.ctypes.data, gpu_ptr, size, 2)
|
||||
|
||||
# Free
|
||||
hip.hipFree(gpu_ptr)
|
||||
```
|
||||
|
||||
## Performance Testing
|
||||
|
||||
Test compilation performance with different kernel counts:
|
||||
|
||||
```bash
|
||||
# Test with 10 kernels (~15s compile time)
|
||||
python3 01_basic_gemm.py --num-kernels 10
|
||||
|
||||
# Test with 20 kernels (~25s compile time)
|
||||
python3 01_basic_gemm.py --num-kernels 20
|
||||
|
||||
# Test with 48 kernels (~50s compile time)
|
||||
python3 01_basic_gemm.py --num-kernels 48
|
||||
```
|
||||
|
||||
Compilation time scales roughly linearly with kernel count.
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [C++ GEMM Examples](../cpp/README.md)
|
||||
- [Python Conv Examples](../../conv/python/README.md)
|
||||
- [Main Dispatcher README](../../../README.md)
|
||||
80
dispatcher/examples/gemm/python/kernels.json
Normal file
80
dispatcher/examples/gemm/python/kernels.json
Normal file
@@ -0,0 +1,80 @@
|
||||
{
|
||||
"registry": "export_demo",
|
||||
"kernel_count": 3,
|
||||
"kernels": [
|
||||
{
|
||||
"tile": "128x128x32",
|
||||
"dtypes": {
|
||||
"A": "fp16",
|
||||
"B": "fp16",
|
||||
"C": "fp16"
|
||||
},
|
||||
"layout": "rcr",
|
||||
"pipeline": "compv4",
|
||||
"target": "gfx942"
|
||||
},
|
||||
{
|
||||
"tile": "256x256x64",
|
||||
"dtypes": {
|
||||
"A": "fp16",
|
||||
"B": "fp16",
|
||||
"C": "fp16"
|
||||
},
|
||||
"layout": "rcr",
|
||||
"pipeline": "compv4",
|
||||
"target": "gfx942"
|
||||
},
|
||||
{
|
||||
"tile": "64x64x32",
|
||||
"dtypes": {
|
||||
"A": "fp16",
|
||||
"B": "fp16",
|
||||
"C": "fp16"
|
||||
},
|
||||
"layout": "rcr",
|
||||
"pipeline": "compv4",
|
||||
"target": "gfx942"
|
||||
}
|
||||
],
|
||||
"cpp_registry": {
|
||||
"metadata": {
|
||||
"timestamp": "Dec 4 2025 06:23:15",
|
||||
"total_kernels": 1,
|
||||
"export_version": "1.0",
|
||||
"dispatcher_version": "1.0.0"
|
||||
},
|
||||
"statistics": {
|
||||
"by_datatype": {},
|
||||
"by_pipeline": {},
|
||||
"by_scheduler": {}
|
||||
},
|
||||
"kernels": [
|
||||
{
|
||||
"identifier": "128x128x32_2x2x1_32x32x16_nopers",
|
||||
"name": "gemm_fp16_rcrr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16",
|
||||
"algorithm": {
|
||||
"tile_shape": {
|
||||
"m": 128,
|
||||
"n": 128,
|
||||
"k": 32
|
||||
},
|
||||
"wave_shape": {
|
||||
"m": 2,
|
||||
"n": 2,
|
||||
"k": 1
|
||||
},
|
||||
"warp_tile_shape": {
|
||||
"m": 32,
|
||||
"n": 32,
|
||||
"k": 16
|
||||
},
|
||||
"block_size": 256,
|
||||
"persistent": false,
|
||||
"double_buffer": true,
|
||||
"preshuffle": false,
|
||||
"transpose_c": false
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user