mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Merge remote-tracking branch 'origin/ginolu/add_wgmfma_dispatcher' into mtgu/cktile_mxfp4_flatmm_dev
This commit is contained in:
112
.github/scripts/therock_configure_ci.py
vendored
Normal file
112
.github/scripts/therock_configure_ci.py
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Iterable, Optional, Mapping
|
||||
|
||||
def gha_set_output(vars: Mapping[str, str | Path]):
|
||||
"""Sets values in a step's output parameters.
|
||||
|
||||
This appends to the file located at the $GITHUB_OUTPUT environment variable.
|
||||
|
||||
See
|
||||
* https://docs.github.com/en/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter
|
||||
* https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs
|
||||
"""
|
||||
print(f"Setting github output:\n{vars}")
|
||||
|
||||
step_output_file = os.getenv("GITHUB_OUTPUT")
|
||||
if not step_output_file:
|
||||
print(" Warning: GITHUB_OUTPUT env var not set, can't set github outputs")
|
||||
return
|
||||
|
||||
with open(step_output_file, "a") as f:
|
||||
f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items())
|
||||
|
||||
def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
|
||||
"""Returns the paths of modified files relative to the base reference."""
|
||||
try:
|
||||
return subprocess.run(
|
||||
["git", "diff", "--name-only", base_ref],
|
||||
stdout=subprocess.PIPE,
|
||||
check=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
).stdout.splitlines()
|
||||
except TimeoutError:
|
||||
print(
|
||||
"Computing modified files timed out. Not using PR diff to determine"
|
||||
" jobs to run.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return None
|
||||
|
||||
# Paths matching any of these patterns are considered to have no influence over
|
||||
# build or test workflows so any related jobs can be skipped if all paths
|
||||
# modified by a commit/PR match a pattern in this list.
|
||||
SKIPPABLE_PATH_PATTERNS = [
|
||||
"docs/*",
|
||||
"*.gitignore",
|
||||
"*.md",
|
||||
"*.pre-commit-config.*",
|
||||
"*LICENSE",
|
||||
'Jenkinsfile',
|
||||
'.github/ISSUE_TEMPLATE/*',
|
||||
'.github/CODEOWNERS',
|
||||
'.github/*.md',
|
||||
'.github/dependabot.yml',
|
||||
]
|
||||
|
||||
def is_path_skippable(path: str) -> bool:
|
||||
"""Determines if a given relative path to a file matches any skippable patterns."""
|
||||
return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS)
|
||||
|
||||
def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool:
|
||||
"""Returns true if at least one path is not in the skippable set."""
|
||||
if paths is None:
|
||||
return False
|
||||
return any(not is_path_skippable(p) for p in paths)
|
||||
|
||||
def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
|
||||
"""Returns true if CI workflows should run given a list of modified paths."""
|
||||
|
||||
if paths is None:
|
||||
print("No files were modified, skipping TheRock CI jobs")
|
||||
return False
|
||||
|
||||
paths_set = set(paths)
|
||||
github_workflows_paths = set(
|
||||
[p for p in paths if p.startswith(".github/workflows")]
|
||||
)
|
||||
other_paths = paths_set - github_workflows_paths
|
||||
|
||||
contains_other_non_skippable_files = check_for_non_skippable_path(other_paths)
|
||||
|
||||
print("should_ci_run_given_modified_paths findings:")
|
||||
print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}")
|
||||
|
||||
if contains_other_non_skippable_files:
|
||||
print("Enabling TheRock CI jobs since a non-skippable path was modified")
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
"Only unrelated and/or skippable paths were modified, skipping TheRock CI jobs"
|
||||
)
|
||||
return False
|
||||
|
||||
def main(args):
|
||||
base_ref = args.get("base_ref")
|
||||
modified_paths = get_modified_paths(base_ref)
|
||||
print("modified_paths (max 200):", modified_paths[:200])
|
||||
enable_jobs = should_ci_run_given_modified_paths(modified_paths)
|
||||
output = {
|
||||
'enable_therock_ci': json.dumps(enable_jobs)
|
||||
}
|
||||
gha_set_output(output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = {}
|
||||
args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1")
|
||||
main(args)
|
||||
130
.github/workflows/therock-ci-linux.yml
vendored
Normal file
130
.github/workflows/therock-ci-linux.yml
vendored
Normal file
@@ -0,0 +1,130 @@
|
||||
name: TheRock CI Linux
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
cmake_options:
|
||||
type: string
|
||||
amdgpu_families:
|
||||
type: string
|
||||
test_runs_on:
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
therock-build-linux:
|
||||
name: Build Linux Packages
|
||||
runs-on: azure-linux-scale-rocm
|
||||
permissions:
|
||||
id-token: write
|
||||
container:
|
||||
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292
|
||||
options: -v /runner/config:/home/awsconfig/
|
||||
env:
|
||||
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
|
||||
TEATIME_FORCE_INTERACTIVE: 0
|
||||
AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini
|
||||
steps:
|
||||
- name: Checkout composable_kernel repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Checkout TheRock repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57
|
||||
path: "TheRock"
|
||||
|
||||
- name: Runner Health Settings
|
||||
run: |
|
||||
df -h
|
||||
cmake --version
|
||||
echo "Installed Python versions:"
|
||||
ls -d /opt/python
|
||||
echo "python: $(which python), python3: $(which python3)"
|
||||
echo "Git version: $(git --version)"
|
||||
git config --global --add safe.directory $PWD
|
||||
git config fetch.parallel 10
|
||||
|
||||
- name: Fetch sources
|
||||
run: |
|
||||
./TheRock/build_tools/fetch_sources.py --jobs 12
|
||||
|
||||
- name: Install python deps
|
||||
run: |
|
||||
pip install -r TheRock/requirements.txt
|
||||
pip freeze
|
||||
|
||||
- name: Configure Projects
|
||||
env:
|
||||
amdgpu_families: ${{ env.AMDGPU_FAMILIES }}
|
||||
package_version: ADHOCBUILD
|
||||
extra_cmake_options: ${{ inputs.cmake_options }}
|
||||
BUILD_DIR: build
|
||||
run: |
|
||||
python3 TheRock/build_tools/github_actions/build_configure.py
|
||||
|
||||
- name: Build TheRock
|
||||
run: cmake --build TheRock/build
|
||||
|
||||
- name: Build therock-archives
|
||||
run: cmake --build TheRock/build --target therock-archives
|
||||
|
||||
- name: Report
|
||||
if: ${{ !cancelled() }}
|
||||
run: |
|
||||
echo "Full SDK du:"
|
||||
echo "------------"
|
||||
du -h -d 1 TheRock/build/dist/rocm
|
||||
echo "Artifact Archives:"
|
||||
echo "------------------"
|
||||
ls -lh TheRock/build/artifacts/*.tar.xz
|
||||
echo "Artifacts:"
|
||||
echo "----------"
|
||||
du -h -d 1 TheRock/build/artifacts
|
||||
|
||||
- name: Configure AWS Credentials for non-forked repos
|
||||
if: ${{ always() && !github.event.pull_request.head.repo.fork }}
|
||||
uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1
|
||||
with:
|
||||
aws-region: us-east-2
|
||||
role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external
|
||||
|
||||
- name: Create Logs index Files and upload logs
|
||||
if: always()
|
||||
run: |
|
||||
python3 TheRock/build_tools/github_actions/create_log_index.py \
|
||||
--build-dir=TheRock/build \
|
||||
--amdgpu-family=${{ env.AMDGPU_FAMILIES }}
|
||||
|
||||
python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \
|
||||
--build-dir=TheRock/build \
|
||||
--run-id ${{ github.run_id }} \
|
||||
--amdgpu-family ${{ env.AMDGPU_FAMILIES }}
|
||||
|
||||
- name: Upload artifacts
|
||||
run: |
|
||||
python TheRock/build_tools/github_actions/upload_build_artifacts.py \
|
||||
--run-id ${{ github.run_id }} \
|
||||
--amdgpu-family ${{ env.AMDGPU_FAMILIES }} \
|
||||
--build-dir TheRock/build
|
||||
|
||||
- name: Add Links to Job Summary
|
||||
if: always()
|
||||
run: |
|
||||
python TheRock/build_tools/github_actions/upload_build_summary.py \
|
||||
--run-id ${{ github.run_id }} \
|
||||
--amdgpu-family ${{ env.AMDGPU_FAMILIES }} \
|
||||
--build-dir TheRock/build
|
||||
|
||||
therock-test-linux:
|
||||
name: "Test"
|
||||
needs: [therock-build-linux]
|
||||
uses: ./.github/workflows/therock-test-packages.yml
|
||||
with:
|
||||
project_to_test: "miopen"
|
||||
amdgpu_families: ${{ inputs.amdgpu_families }}
|
||||
test_runs_on: ${{ inputs.test_runs_on }}
|
||||
platform: "linux"
|
||||
81
.github/workflows/therock-ci.yml
vendored
Normal file
81
.github/workflows/therock-ci.yml
vendored
Normal file
@@ -0,0 +1,81 @@
|
||||
name: TheRock CI for composable_kernel
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- develop
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- synchronize
|
||||
branches:
|
||||
- mainline
|
||||
- release/*
|
||||
- release-staging/*
|
||||
- develop
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
# A PR number if a pull request and otherwise the commit hash. This cancels
|
||||
# queued and in-progress runs for the same PR (presubmit) or commit
|
||||
# (postsubmit). The workflow name is prepended to avoid conflicts between
|
||||
# different workflows.
|
||||
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-24.04
|
||||
env:
|
||||
# The commit being checked out is the merge commit for a PR. Its first
|
||||
# parent will be the tip of the base branch.
|
||||
BASE_REF: HEAD^
|
||||
outputs:
|
||||
enable_therock_ci: ${{ steps.configure.outputs.enable_therock_ci }}
|
||||
steps:
|
||||
- name: "Checking out repository"
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
# We need the parent commit to do a diff
|
||||
fetch-depth: 2
|
||||
|
||||
- name: "Configuring CI options"
|
||||
id: configure
|
||||
run: python .github/scripts/therock_configure_ci.py
|
||||
|
||||
therock-ci-linux:
|
||||
name: TheRock CI Linux
|
||||
needs: setup
|
||||
if: ${{ needs.setup.outputs.enable_therock_ci == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
uses: ./.github/workflows/therock-ci-linux.yml
|
||||
secrets: inherit
|
||||
with:
|
||||
cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../"
|
||||
amdgpu_families: "gfx94X-dcgpu"
|
||||
test_runs_on: "linux-mi325-1gpu-ossci-rocm"
|
||||
|
||||
therock_ci_summary:
|
||||
name: TheRock CI Summary
|
||||
if: always()
|
||||
needs:
|
||||
- setup
|
||||
- therock-ci-linux
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- name: Output failed jobs
|
||||
run: |
|
||||
echo '${{ toJson(needs) }}'
|
||||
FAILED_JOBS="$(echo '${{ toJson(needs) }}' \
|
||||
| jq --raw-output \
|
||||
'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \
|
||||
)"
|
||||
if [[ "${FAILED_JOBS}" != "" ]]; then
|
||||
echo "The following jobs failed: ${FAILED_JOBS}"
|
||||
exit 1
|
||||
fi
|
||||
77
.github/workflows/therock-test-packages.yml
vendored
Normal file
77
.github/workflows/therock-test-packages.yml
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
name: TheRock Test Packages
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
project_to_test:
|
||||
type: string
|
||||
amdgpu_families:
|
||||
type: string
|
||||
test_runs_on:
|
||||
type: string
|
||||
platform:
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
configure_test_matrix:
|
||||
name: "Configure test matrix"
|
||||
runs-on: ubuntu-24.04
|
||||
if: ${{ inputs.test_runs_on != '' }}
|
||||
outputs:
|
||||
components: ${{ steps.configure.outputs.components }}
|
||||
steps:
|
||||
- name: "Checking out repository"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
|
||||
- name: "Configuring CI options"
|
||||
env:
|
||||
PLATFORM: ${{ inputs.platform }}
|
||||
project_to_test: ${{ inputs.project_to_test }}
|
||||
id: configure
|
||||
run: python ./build_tools/github_actions/fetch_test_configurations.py
|
||||
|
||||
test_components:
|
||||
name: 'Test ${{ matrix.components.job_name }}'
|
||||
runs-on: ${{ inputs.test_runs_on }}
|
||||
needs: configure_test_matrix
|
||||
# skip tests if no test matrix to run
|
||||
if: ${{ needs.configure_test_matrix.outputs.components != '[]' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
env:
|
||||
VENV_DIR: ${{ github.workspace }}/.venv
|
||||
ARTIFACT_RUN_ID: "${{ github.run_id }}"
|
||||
OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build
|
||||
THEROCK_BIN_DIR: "./build/bin"
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
|
||||
- name: Run setup test environment workflow
|
||||
uses: './.github/actions/setup_test_environment'
|
||||
with:
|
||||
ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }}
|
||||
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
|
||||
OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }}
|
||||
VENV_DIR: ${{ env.VENV_DIR }}
|
||||
FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }}
|
||||
PLATFORM: ${{ inputs.platform }}
|
||||
IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }}
|
||||
|
||||
- name: Test
|
||||
timeout-minutes: ${{ matrix.components.timeout_minutes }}
|
||||
run: |
|
||||
if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . ${VENV_DIR}/Scripts/activate ; fi
|
||||
${{ matrix.components.test_script }}
|
||||
@@ -3,7 +3,7 @@ repos:
|
||||
hooks:
|
||||
- id: clang-format
|
||||
name: clang-format
|
||||
entry: clang-format-12 -i --style=file
|
||||
entry: clang-format-18 -i --style=file
|
||||
language: system
|
||||
types_or: [c++, inc]
|
||||
- id: copyright-year-checker
|
||||
|
||||
14
CHANGELOG.md
14
CHANGELOG.md
@@ -2,10 +2,11 @@
|
||||
|
||||
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
|
||||
|
||||
## Composable Kernel 1.1.0 for ROCm 6.5.0
|
||||
## Composable Kernel 1.2.0 for ROCm 7.0.0
|
||||
|
||||
### Added
|
||||
|
||||
* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.
|
||||
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
|
||||
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
|
||||
* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.
|
||||
* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced).
|
||||
@@ -19,10 +20,14 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
* Added logit soft-capping support for fMHA forward kernels.
|
||||
* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv)
|
||||
* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv/bwd)
|
||||
* Added benchmarking support for tile engine GEMM.
|
||||
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
|
||||
* Added rotating buffer feature for CK_Tile GEMM.
|
||||
* Added int8 support for CK_TILE GEMM.
|
||||
* Added support for elementwise kernel.
|
||||
* Added benchmarking support for tile engine GEMM Multi D.
|
||||
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
|
||||
|
||||
### Optimized
|
||||
|
||||
@@ -44,11 +49,16 @@ None
|
||||
* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced.
|
||||
* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced.
|
||||
* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced.
|
||||
* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594)
|
||||
|
||||
### Known issues
|
||||
|
||||
None
|
||||
|
||||
### Upcoming changes
|
||||
|
||||
* Non-grouped convolutions are deprecated. All of their functionality is supported by grouped convolution.
|
||||
|
||||
## Composable Kernel 1.1.0 for ROCm 6.1.0
|
||||
|
||||
### Additions
|
||||
|
||||
@@ -16,12 +16,21 @@ else()
|
||||
"Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.")
|
||||
endif()
|
||||
|
||||
# Allow user to specify the C++ standard.
|
||||
# We must support C++17 builds until downstream users are migrated to C++20, but we default to C++20.
|
||||
set(CK_CXX_STANDARD "20" CACHE STRING "C++ standard to use (e.g. 17 or 20)")
|
||||
set(valid_cxx_standards 17 20)
|
||||
set_property(CACHE CK_CXX_STANDARD PROPERTY STRINGS ${valid_cxx_standards})
|
||||
if(NOT CK_CXX_STANDARD IN_LIST valid_cxx_standards)
|
||||
message(FATAL_ERROR "CK_CXX_STANDARD must be one of ${valid_cxx_standards}")
|
||||
endif()
|
||||
|
||||
# Default installation path
|
||||
if(NOT WIN32)
|
||||
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
|
||||
endif()
|
||||
|
||||
set(version 1.1.0)
|
||||
set(version 1.2.0)
|
||||
# Check support for CUDA/HIP in Cmake
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
|
||||
include(CTest)
|
||||
@@ -98,6 +107,12 @@ add_compile_options(-Wno-pass-failed)
|
||||
add_compile_options(-Wno-switch-default)
|
||||
add_compile_options(-Wno-unique-object-duplication)
|
||||
|
||||
# add -Og -gdwarf64 for debug builds
|
||||
add_compile_options(
|
||||
"$<$<CONFIG:Debug>:-Og>"
|
||||
"$<$<CONFIG:Debug>:-gdwarf64>"
|
||||
)
|
||||
|
||||
# Recent change in compiler makes this warning ON by default, which led to compile errors.
|
||||
add_compile_options(-Wno-nrvo)
|
||||
|
||||
@@ -215,11 +230,24 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9
|
||||
add_definitions(-DCK_USE_GFX94)
|
||||
set(CK_USE_GFX94 "ON")
|
||||
endif()
|
||||
|
||||
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
|
||||
set(CK_TILE_USE_WMMA 0)
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx10")
|
||||
add_definitions(-DCK_GFX1030_SUPPORT)
|
||||
endif()
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message(STATUS "Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
set(CK_TILE_USE_WMMA 1)
|
||||
endif()
|
||||
|
||||
# define the macro with the current value (0 or 1)
|
||||
add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_WMMA_FP8)
|
||||
@@ -236,6 +264,8 @@ endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
|
||||
add_definitions(-DCK_USE_NATIVE_MX_SUPPORT)
|
||||
set(CK_USE_NATIVE_MX_SUPPORT "ON")
|
||||
add_definitions(-DCK_GFX950_SUPPORT)
|
||||
set(CK_GFX950_SUPPORT "ON")
|
||||
endif()
|
||||
|
||||
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
|
||||
@@ -316,12 +346,6 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
|
||||
endif()
|
||||
|
||||
if(USE_OPT_GFX11)
|
||||
add_compile_options(-mcumode)
|
||||
add_compile_options(-mno-wavefrontsize64)
|
||||
message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
|
||||
endif()
|
||||
|
||||
if(ENABLE_ASM_DUMP)
|
||||
add_compile_options(--save-temps)
|
||||
add_compile_options(-Wno-gnu-line-marker)
|
||||
@@ -334,7 +358,7 @@ find_package(Threads REQUIRED)
|
||||
link_libraries(Threads::Threads)
|
||||
|
||||
## C++
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD ${CK_CXX_STANDARD})
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
|
||||
|
||||
@@ -22,6 +22,9 @@ Xiaoyan Zhou, 2020
|
||||
[Jianfeng Yan](https://github.com/j4yan), 2021-2022
|
||||
[Jun Liu](https://github.com/junliume), 2021-2024
|
||||
|
||||
[John Shumway](https://github.com/shumway), [Vidyasagar Ananthan](https://github.com/vidyasagar-amd), [Christopher Millette](https://github.com/cgmillette), [Maksim Podkorytov](https://github.com/tenpercent), [Thomas Ning](https://github.com/ThomasNing),[Andriy Roshchenko](https://github.com/andriy-ca), [Aviral Goel](https://github.com/AviralGoelAMD), [Cong Ma](https://github.com/CongMa13),[Thrupti Raj Lakshmana Gowda](https://github.com/ThruptiRajLakshmanaGowda), [Emily Martins](https://github.com/ecamartins), [Khushbu Agarwal](https://github.com/amd-khushbu), [Sudhir Kylasa](https://github.com/kylasa), [Jia Luo](https://github.com/JiaLuo-CAN), 2025-
|
||||
|
||||
|
||||
## Product Manager
|
||||
[John Afaganis](https://github.com/afagaj)
|
||||
|
||||
|
||||
@@ -62,6 +62,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
libzstd-dev \
|
||||
openssh-server \
|
||||
clang-format-12 \
|
||||
clang-format-18 \
|
||||
kmod && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
|
||||
21
Dockerfile.aiter
Normal file
21
Dockerfile.aiter
Normal file
@@ -0,0 +1,21 @@
|
||||
ARG BASE_DOCKER="rocm/pytorch:latest"
|
||||
FROM $BASE_DOCKER
|
||||
ARG AITER_BRANCH="main"
|
||||
ARG CK_AITER_BRANCH="develop"
|
||||
RUN groupadd -g 109 render && \
|
||||
usermod -u 1001 jenkins && \
|
||||
groupmod -g 1001 jenkins && \
|
||||
pip install pandas zmq einops && \
|
||||
pip install numpy==1.26.2 && \
|
||||
sudo mkdir /home/jenkins && \
|
||||
sudo mkdir /home/jenkins/workspace && \
|
||||
cd /home/jenkins/workspace && \
|
||||
rm -rf aiter && \
|
||||
git clone -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \
|
||||
cd aiter && \
|
||||
rm -rf 3rdparty/composable_kernel/ && \
|
||||
git clone -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \
|
||||
python3 setup.py develop && \
|
||||
chown -R jenkins:jenkins /home/jenkins/workspace && \
|
||||
chmod -R a+rwx /home/jenkins/workspace && \
|
||||
sudo usermod -aG irc jenkins
|
||||
23
Dockerfile.pytorch
Normal file
23
Dockerfile.pytorch
Normal file
@@ -0,0 +1,23 @@
|
||||
ARG BASE_DOCKER="rocm/pytorch-nightly:latest"
|
||||
FROM $BASE_DOCKER
|
||||
ARG CK_PYTORCH_BRANCH="develop"
|
||||
RUN groupadd -g 109 render && \
|
||||
usermod -u 1001 jenkins && \
|
||||
groupmod -g 1001 jenkins && \
|
||||
cd /tmp/pytorch && \
|
||||
rm -rf build && \
|
||||
cd /tmp/pytorch/third_party && \
|
||||
rm -rf composable_kernel && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
cd /tmp/pytorch/third_party/aiter/3rdparty && \
|
||||
rm -rf composable_kernel && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
cd /tmp/pytorch/third_party/fbgemm/external && \
|
||||
rm -rf composable_kernel && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
cd /tmp/pytorch/third_party/flash-attention/csrc && \
|
||||
rm -rf composable_kernel && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
chown -R jenkins:jenkins /tmp/pytorch && \
|
||||
chmod -R a+rwx /tmp/pytorch && \
|
||||
sudo usermod -aG irc jenkins
|
||||
573
Jenkinsfile
vendored
573
Jenkinsfile
vendored
@@ -33,9 +33,6 @@ def nthreads() {
|
||||
def nproc = sh(returnStdout: true, script: 'nproc')
|
||||
echo "Number of cores: ${nproc}"
|
||||
def n = nproc.toInteger()
|
||||
if (n > 32){
|
||||
n /= 2
|
||||
}
|
||||
if (n > 64){
|
||||
n = 64
|
||||
}
|
||||
@@ -188,12 +185,20 @@ def buildDocker(install_prefix){
|
||||
if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
|
||||
dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f Dockerfile.compiler . "
|
||||
}
|
||||
else{
|
||||
else if(params.RUN_AITER_TESTS){
|
||||
image_name = "rocm/composable_kernel:ck_aiter"
|
||||
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . "
|
||||
}
|
||||
else if(params.RUN_PYTORCH_TESTS){
|
||||
image_name = "rocm/composable_kernel:ck_pytorch"
|
||||
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . "
|
||||
}
|
||||
else{
|
||||
dockerArgs = dockerArgs + " -f Dockerfile . "
|
||||
}
|
||||
echo "Build Args: ${dockerArgs}"
|
||||
try{
|
||||
if(params.BUILD_DOCKER){
|
||||
if(params.BUILD_DOCKER || params.RUN_AITER_TESTS || params.RUN_PYTORCH_TESTS){
|
||||
//force building the new docker if that parameter is true
|
||||
echo "Building image: ${image_name}"
|
||||
retimage = docker.build("${image_name}", dockerArgs)
|
||||
@@ -234,11 +239,6 @@ def cmake_build(Map conf=[:]){
|
||||
|
||||
def build_type_debug = (conf.get("build_type",'release') == 'debug')
|
||||
|
||||
// use special compiler for gfx950
|
||||
if ( check_arch() == 7){
|
||||
compiler = "/llvm-project/build/bin/clang++"
|
||||
}
|
||||
|
||||
//cmake_env can overwrite default CXX variables.
|
||||
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
|
||||
|
||||
@@ -401,8 +401,9 @@ def cmake_build(Map conf=[:]){
|
||||
echo "Build packages"
|
||||
sh 'ninja -j64 package'
|
||||
archiveArtifacts artifacts: 'composablekernel-dev*.deb'
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb'
|
||||
stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages"
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64.deb'
|
||||
stash includes: "composablekernel-**.deb", name: "packages"
|
||||
}
|
||||
}
|
||||
else{
|
||||
@@ -439,34 +440,6 @@ def cmake_build(Map conf=[:]){
|
||||
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
|
||||
}
|
||||
}
|
||||
if (params.RUN_CK_TILE_TRANSPOSE_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_transpose_*.log"
|
||||
if (arch_type == 1){
|
||||
stash includes: "perf_transpose_**_gfx90a.log", name: "perf_transpose_log_gfx90a"
|
||||
}
|
||||
else if (arch_type == 2){
|
||||
stash includes: "perf_transpose_**_gfx942.log", name: "perf_transpose_log_gfx942"
|
||||
}
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
|
||||
}
|
||||
}
|
||||
if (params.RUN_CK_TILE_GEMM_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_tile_gemm_**.log"
|
||||
if (arch == 1){
|
||||
stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a"
|
||||
}
|
||||
else if (arch == 2){
|
||||
stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942"
|
||||
}
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def buildHipClangJob(Map conf=[:]){
|
||||
@@ -489,7 +462,9 @@ def buildHipClangJob(Map conf=[:]){
|
||||
}
|
||||
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
|
||||
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
|
||||
// the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with
|
||||
// newer clang22 compilers and running with older hip runtima libraries
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 "
|
||||
}
|
||||
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
|
||||
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
|
||||
@@ -547,7 +522,9 @@ def Build_CK(Map conf=[:]){
|
||||
}
|
||||
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
|
||||
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
|
||||
// the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with
|
||||
// newer clang22 compilers and running with older hip runtima libraries
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 "
|
||||
}
|
||||
if(params.BUILD_LEGACY_OS){
|
||||
dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' "
|
||||
@@ -596,50 +573,66 @@ def Build_CK(Map conf=[:]){
|
||||
python3 -m pytest python/test/test_gen_instances.py
|
||||
"""
|
||||
}
|
||||
dir("build"){
|
||||
if (params.RUN_FULL_QA && arch == 2 ){
|
||||
// build deb packages
|
||||
echo "Build packages"
|
||||
sh 'make -j package'
|
||||
archiveArtifacts artifacts: 'composablekernel*.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb'
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb'
|
||||
sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb'
|
||||
sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb'
|
||||
stash includes: "composablekernel-**.deb", name: "packages"
|
||||
}
|
||||
}
|
||||
// run performance tests, stash the logs, results will be processed on the master node
|
||||
dir("script"){
|
||||
if (params.RUN_PERFORMANCE_TESTS){
|
||||
if (params.RUN_FULL_QA && arch == 1){
|
||||
// run full tests on gfx90a
|
||||
echo "Run full performance tests"
|
||||
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
|
||||
archiveArtifacts "perf_gemm.log"
|
||||
archiveArtifacts "perf_resnet50_N256.log"
|
||||
archiveArtifacts "perf_resnet50_N4.log"
|
||||
archiveArtifacts "perf_batched_gemm.log"
|
||||
archiveArtifacts "perf_grouped_gemm.log"
|
||||
archiveArtifacts "perf_grouped_conv_fwd.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_data.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_weight.log"
|
||||
archiveArtifacts "perf_gemm_bilinear.log"
|
||||
archiveArtifacts "perf_reduction.log"
|
||||
archiveArtifacts "perf_splitK_gemm.log"
|
||||
archiveArtifacts "perf_onnx_gemm.log"
|
||||
archiveArtifacts "perf_mixed_gemm.log"
|
||||
stash includes: "perf_**.log", name: "perf_log"
|
||||
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a"
|
||||
archiveArtifacts "perf_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_resnet50_N256_gfx90a.log"
|
||||
archiveArtifacts "perf_resnet50_N4_gfx90a.log"
|
||||
archiveArtifacts "perf_batched_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_grouped_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_grouped_conv_fwd_gfx90a.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_data_gfx90a.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_weight_gfx90a.log"
|
||||
archiveArtifacts "perf_gemm_bilinear_gfx90a.log"
|
||||
archiveArtifacts "perf_reduction_gfx90a.log"
|
||||
archiveArtifacts "perf_splitK_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_mixed_gemm_gfx90a.log"
|
||||
stash includes: "perf_**.log", name: "perf_log_gfx90a"
|
||||
}
|
||||
if (params.RUN_FULL_QA && arch == 2){
|
||||
// run full tests on gfx942
|
||||
echo "Run full performance tests"
|
||||
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942"
|
||||
archiveArtifacts "perf_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_resnet50_N256_gfx942.log"
|
||||
archiveArtifacts "perf_resnet50_N4_gfx942.log"
|
||||
archiveArtifacts "perf_batched_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_grouped_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_grouped_conv_fwd_gfx942.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_data_gfx942.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_weight_gfx942.log"
|
||||
archiveArtifacts "perf_gemm_bilinear_gfx942.log"
|
||||
archiveArtifacts "perf_reduction_gfx942.log"
|
||||
archiveArtifacts "perf_splitK_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_mixed_gemm_gfx942.log"
|
||||
stash includes: "perf_**.log", name: "perf_log_gfx942"
|
||||
}
|
||||
else if ( arch == 1 ){
|
||||
// run standard tests on gfx90a
|
||||
echo "Run performance tests"
|
||||
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
|
||||
archiveArtifacts "perf_gemm.log"
|
||||
archiveArtifacts "perf_onnx_gemm.log"
|
||||
archiveArtifacts "perf_resnet50_N256.log"
|
||||
archiveArtifacts "perf_resnet50_N4.log"
|
||||
stash includes: "perf_**.log", name: "perf_log"
|
||||
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a"
|
||||
archiveArtifacts "perf_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_resnet50_N256_gfx90a.log"
|
||||
archiveArtifacts "perf_resnet50_N4_gfx90a.log"
|
||||
stash includes: "perf_**.log", name: "perf_log_gfx90a"
|
||||
}
|
||||
else if ( arch == 2 ){
|
||||
// run standard tests on gfx942
|
||||
echo "Run performance tests"
|
||||
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942"
|
||||
archiveArtifacts "perf_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_resnet50_N256_gfx942.log"
|
||||
archiveArtifacts "perf_resnet50_N4_gfx942.log"
|
||||
stash includes: "perf_**.log", name: "perf_log_gfx942"
|
||||
}
|
||||
// disable performance tests on gfx1030 for now.
|
||||
//else if ( arch == 3){
|
||||
@@ -757,47 +750,64 @@ def process_results(Map conf=[:]){
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
try{
|
||||
unstash "perf_fmha_log_gfx942"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the FMHA performance logs for gfx942: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_fmha_log_gfx90a"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the FMHA performance logs: ${err.getMessage()}."
|
||||
echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}."
|
||||
}
|
||||
}
|
||||
if (params.RUN_CK_TILE_TRANSPOSE_TESTS){
|
||||
try{
|
||||
unstash "perf_transpose_log_gfx942"
|
||||
unstash "perf_transpose_log_gfx90a"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the Transpose performance logs: ${err.getMessage()}."
|
||||
}
|
||||
}
|
||||
if (params.RUN_CK_TILE_GEMM_TESTS){
|
||||
try{
|
||||
unstash "perf_tile_gemm_log_gfx942"
|
||||
unstash "perf_tile_gemm_log_gfx90a"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the GEMM performance logs: ${err.getMessage()}."
|
||||
}
|
||||
}
|
||||
if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){
|
||||
if (params.BUILD_INSTANCES_ONLY){
|
||||
// unstash deb packages
|
||||
unstash "packages"
|
||||
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
|
||||
}
|
||||
else{
|
||||
// unstash perf files to master
|
||||
unstash "perf_log"
|
||||
try{
|
||||
unstash "perf_log_gfx90a"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx90a performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx942"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx942 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx950"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx950 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx908"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx908 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx11"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx11 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
|
||||
unstash "perf_log_gfx12"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}."
|
||||
echo "could not locate the gfx12 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
sh "./process_perf_data.sh"
|
||||
}
|
||||
// process the logs
|
||||
sh "./process_perf_data.sh"
|
||||
}
|
||||
}
|
||||
catch(e){
|
||||
@@ -812,13 +822,121 @@ def process_results(Map conf=[:]){
|
||||
}
|
||||
}
|
||||
|
||||
def run_aiter_tests(Map conf=[:]){
|
||||
show_node_info()
|
||||
env.HSA_ENABLE_SDMA=0
|
||||
checkout scm
|
||||
//use the latest pytorch image
|
||||
def image = "rocm/composable_kernel:ck_aiter"
|
||||
def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins"
|
||||
def variant = env.STAGE_NAME
|
||||
def retimage
|
||||
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
|
||||
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
|
||||
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
|
||||
echo "Docker flags: ${dockerOpts}"
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
try
|
||||
{
|
||||
echo "Pulling image: ${image}"
|
||||
retimage = docker.image("${image}")
|
||||
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
|
||||
retimage.pull()
|
||||
}
|
||||
}
|
||||
catch(Exception ex)
|
||||
{
|
||||
error "Unable to locate image: ${image}"
|
||||
}
|
||||
}
|
||||
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 2, unit: 'HOURS'){
|
||||
try{
|
||||
sh "rocminfo"
|
||||
sh "python3 --version"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_ep.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_sorting.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_sorting_mxfp4.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_tkw1.py"
|
||||
}
|
||||
catch(e){
|
||||
echo "Throwing error exception while running AITER tests"
|
||||
echo 'Exception occurred: ' + e.toString()
|
||||
throw e
|
||||
}
|
||||
finally{
|
||||
echo "Finished running AITER tests"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def run_pytorch_tests(Map conf=[:]){
|
||||
show_node_info()
|
||||
env.HSA_ENABLE_SDMA=0
|
||||
checkout scm
|
||||
//use the latest pytorch-nightly image
|
||||
def image = "rocm/composable_kernel:ck_pytorch"
|
||||
def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins"
|
||||
def variant = env.STAGE_NAME
|
||||
def retimage
|
||||
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
|
||||
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
|
||||
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
|
||||
echo "Docker flags: ${dockerOpts}"
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
try
|
||||
{
|
||||
echo "Pulling image: ${image}"
|
||||
retimage = docker.image("${image}")
|
||||
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
|
||||
retimage.pull()
|
||||
}
|
||||
}
|
||||
catch(Exception ex)
|
||||
{
|
||||
error "Unable to locate image: ${image}"
|
||||
}
|
||||
}
|
||||
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 2, unit: 'HOURS'){
|
||||
try{
|
||||
sh "rocminfo"
|
||||
sh "python3 --version"
|
||||
sh "python3 /tmp/pytorch/tools/amd_build/build_amd.py"
|
||||
sh "USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python /tmp/pytorch/setup.py develop"
|
||||
}
|
||||
catch(e){
|
||||
echo "Throwing error exception while building Pytorch"
|
||||
echo 'Exception occurred: ' + e.toString()
|
||||
throw e
|
||||
}
|
||||
finally{
|
||||
echo "Finished building Pytorch"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : ""
|
||||
0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false
|
||||
0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX10=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false''' : ""
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
@@ -885,6 +1003,10 @@ pipeline {
|
||||
name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the grouped conv large cases tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CONV_COMPREHENSIVE_DATASET",
|
||||
defaultValue: false,
|
||||
description: "Run comprehensive convolution dataset tests before important changes (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CODEGEN_TESTS",
|
||||
defaultValue: true,
|
||||
@@ -893,14 +1015,6 @@ pipeline {
|
||||
name: "RUN_CK_TILE_FMHA_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the ck_tile FMHA tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CK_TILE_TRANSPOSE_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the ck_tile Transpose tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_CK_TILE_GEMM_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run the ck_tile GEMM tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_TILE_ENGINE_GEMM_TESTS",
|
||||
defaultValue: false,
|
||||
@@ -919,8 +1033,8 @@ pipeline {
|
||||
description: "Build CK and run tests on gfx90a (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX942",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx942 (default: ON)")
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx942 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX950",
|
||||
defaultValue: false,
|
||||
@@ -957,6 +1071,26 @@ pipeline {
|
||||
name: "RUN_ALL_UNIT_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run all unit tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_PYTORCH_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Try building PYTORCH with latest CK develop branch (default: OFF)")
|
||||
string(
|
||||
name: 'ck_pytorch_branch',
|
||||
defaultValue: 'develop',
|
||||
description: 'Specify which branch of CK to test with Pytorch (default: develop)')
|
||||
booleanParam(
|
||||
name: "RUN_AITER_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run AITER tests with latest CK develop branch (default: OFF)")
|
||||
string(
|
||||
name: 'aiter_branch',
|
||||
defaultValue: 'main',
|
||||
description: 'Specify which branch of AITER to use (default: main)')
|
||||
string(
|
||||
name: 'ck_aiter_branch',
|
||||
defaultValue: 'develop',
|
||||
description: 'Specify which branch of CK to test with AITER (default: develop)')
|
||||
}
|
||||
environment{
|
||||
dbuser = "${dbuser}"
|
||||
@@ -999,7 +1133,8 @@ pipeline {
|
||||
-o -iname \'*.cpp.in\' \
|
||||
-o -iname \'*.cl\' \
|
||||
| grep -v 'build/' \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \
|
||||
| grep -v 'include/rapidjson' \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \
|
||||
/cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \
|
||||
-D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \
|
||||
-D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \
|
||||
@@ -1028,7 +1163,8 @@ pipeline {
|
||||
-o -iname \'*.cpp.in\' \
|
||||
-o -iname \'*.cl\' \
|
||||
| grep -v 'build/' \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
|
||||
| grep -v 'include/rapidjson' \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'"
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
|
||||
@@ -1036,6 +1172,42 @@ pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run Pytorch Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run Pytorch Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_PYTORCH_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942")}
|
||||
steps{
|
||||
run_pytorch_tests()
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run AITER Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run AITER Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_AITER_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942")}
|
||||
steps{
|
||||
run_aiter_tests()
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run Grouped Conv Large Case Tests")
|
||||
{
|
||||
@@ -1051,8 +1223,40 @@ pipeline {
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases && \
|
||||
./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases"""
|
||||
make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \
|
||||
./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases"""
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run Comprehensive Convolution Dataset Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run Comprehensive Dataset Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_CONV_COMPREHENSIVE_DATASET.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a")}
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cd ../build && \
|
||||
../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make -j64 test_grouped_convnd_fwd_dataset_xdl && \
|
||||
cd ../test_data && \
|
||||
# Dataset generation modes:
|
||||
# - small: ~60 test cases (minimal, quick testing - 3 models, 2 batch sizes, 2 image sizes)
|
||||
# - half: ~300 test cases (moderate coverage - 16 models, 3 batch sizes, 5 image sizes), ~ 17 hours testing time
|
||||
# - full: ~600 test cases (comprehensive - 16 models, 5 batch sizes, 9 image sizes), ~ 40 hours testing time
|
||||
./generate_test_dataset.sh half && \
|
||||
cd ../build && \
|
||||
./bin/test_grouped_convnd_fwd_dataset_xdl"""
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1128,94 +1332,6 @@ pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run CK_TILE_TRANSPOSE Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run CK_TILE_TRANSPOSE Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make -j64 tile_example_batched_transpose && \
|
||||
cd ../ &&
|
||||
example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Run CK_TILE_TRANSPOSE Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
|
||||
make -j64 tile_example_batched_transpose && \
|
||||
cd ../ &&
|
||||
example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run CK_TILE_GEMM Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run CK_TILE_GEMM Tests on gfx90a")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx90a") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make -j64 tile_example_gemm_universal && \
|
||||
cd ../ &&
|
||||
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
stage("Run CK_TILE_GEMM Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942") }
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \
|
||||
make -j64 tile_example_gemm_universal && \
|
||||
cd ../ &&
|
||||
example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run TILE_ENGINE_GEMM Tests")
|
||||
{
|
||||
parallel
|
||||
@@ -1234,11 +1350,21 @@ pipeline {
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx90a" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-D GEMM_MULTI_D_DATATYPE="fp16" \
|
||||
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j64 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j64 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
ninja -j64 benchmark_gemm_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_crrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rcrr """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1259,11 +1385,21 @@ pipeline {
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-D GEMM_MULTI_D_DATATYPE="fp16" \
|
||||
-D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j128 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j128 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
ninja -j64 benchmark_gemm_all && \
|
||||
python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_crrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rcrr """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1288,6 +1424,7 @@ pipeline {
|
||||
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3"
|
||||
setup_args = """ -DGPU_TARGETS="gfx942" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " \
|
||||
-DCK_CXX_STANDARD="17" \
|
||||
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
|
||||
execute_args = " "
|
||||
}
|
||||
@@ -1352,12 +1489,12 @@ pipeline {
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
@@ -1422,7 +1559,7 @@ pipeline {
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """
|
||||
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0")
|
||||
}
|
||||
cleanWs()
|
||||
}
|
||||
@@ -1456,7 +1593,7 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx1101") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx11-generic" \
|
||||
@@ -1477,7 +1614,7 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx1201") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DUSE_OPT_GFX12=ON -DCMAKE_CXX_FLAGS=" -O3 " """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx12-generic" \
|
||||
@@ -1499,7 +1636,7 @@ pipeline {
|
||||
stage("Process results"){
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent { label 'mici' }
|
||||
steps{
|
||||
|
||||
@@ -96,7 +96,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
|
||||
4. Build the entire CK library:
|
||||
|
||||
```bash
|
||||
make -j
|
||||
make -j"$(nproc)"
|
||||
```
|
||||
|
||||
5. Install CK:
|
||||
@@ -213,4 +213,4 @@ script/uninstall_precommit.sh
|
||||
```
|
||||
|
||||
If you need to temporarily disable pre-commit hooks, you can add the `--no-verify` option to the
|
||||
`git commit` command.
|
||||
`git commit` command.
|
||||
|
||||
348
TERMINOLOGY.md
348
TERMINOLOGY.md
@@ -1,2 +1,348 @@
|
||||
[Back to the main page](./README.md)
|
||||
# Composable Kernel terminology
|
||||
|
||||
# Composable Kernel Terminology
|
||||
|
||||
This document provides a technical reference for terminology used in the Composable Kernel library, organized by conceptual progression from hardware to machine learning operations.
|
||||
|
||||
---
|
||||
|
||||
## Glossary Index (Alphabetical)
|
||||
|
||||
- [Add+Multiply](#addmultiply)
|
||||
- [Bank Conflict](#bank-conflict)
|
||||
- [Batched GEMM](#batched-gemm)
|
||||
- [Benchmark](#benchmark)
|
||||
- [Block Size](#block-size)
|
||||
- [Block Tile](#block-tile)
|
||||
- [Compute Unit (CU)](#compute-unit-cu)
|
||||
- [Coordinate Transformation Primitives](#coordinate-transformation-primitives)
|
||||
- [CUDA](#cuda)
|
||||
- [Dense Tensor](#dense-tensor)
|
||||
- [Descriptor](#descriptor)
|
||||
- [Device](#device)
|
||||
- [Elementwise](#elementwise)
|
||||
- [Epilogue](#epilogue)
|
||||
- [Fast Changing Dimension](#fast-changing-dimension)
|
||||
- [GEMM](#gemm-general-matrix-multiply)
|
||||
- [GEMV](#gemv)
|
||||
- [Grouped GEMM](#grouped-gemm)
|
||||
- [Global Memory](#global-memory)
|
||||
- [Grid](#grid)
|
||||
- [Host](#host)
|
||||
- [HIP](#hip)
|
||||
- [Inner Dimension](#inner-dimension)
|
||||
- [Inner Product](#inner-product)
|
||||
- [Input/Problem Shape](#inputproblem-shape)
|
||||
- [Kernel](#kernel)
|
||||
- [Launch Parameters](#launch-parameters)
|
||||
- [Load Tile](#load-tile)
|
||||
- [LDS Banks](#lds-banks)
|
||||
- [Matrix Core](#matrix-core)
|
||||
- [MFMA (Matrix Fused Multiply-Add)](#mfma-matrix-fused-multiply-add)
|
||||
- [Occupancy](#occupancy)
|
||||
- [Outer Dimension](#outer-dimension)
|
||||
- [Outer Product](#outer-product)
|
||||
- [Pinned Memory](#pinned-memory)
|
||||
- [Pipeline](#pipeline)
|
||||
- [Policy](#policy)
|
||||
- [Problem](#problem)
|
||||
- [Processing Units](#processing-units)
|
||||
- [Reference Kernel](#reference-kernel)
|
||||
- [Regression Test](#regression-test)
|
||||
- [ROCm](#rocm)
|
||||
- [Scalar General Purpose Register (SGPR)](#scalar-general-purpose-register-sgpr)
|
||||
- [Shared Memory / LDS (Local Data Share)](#shared-memory--lds-local-data-share)
|
||||
- [SIMT / SIMD](#simt--simd)
|
||||
- [Smoke Test](#smoke-test)
|
||||
- [Sparse Tensor](#sparse-tensor)
|
||||
- [Split-K GEMM](#split-k-gemm)
|
||||
- [Store Tile](#store-tile)
|
||||
- [Thread / Work-item](#thread--work-item)
|
||||
- [Thread Block / Work Group](#thread-block--work-group)
|
||||
- [Vanilla GEMM](#vanilla-gemm)
|
||||
- [Tile](#tile)
|
||||
- [Tile Distribution](#tile-distribution)
|
||||
- [Tile Partitioner](#tile-partitioner)
|
||||
- [Tile Programming API](#tile-programming-api)
|
||||
- [Tile Window](#tile-window)
|
||||
- [User Customized Tile Pipeline](#user-customized-tile-pipeline)
|
||||
- [User Customized Tile Pipeline Optimization](#user-customized-tile-pipeline-optimization)
|
||||
- [Vector](#vector)
|
||||
- [Vector General Purpose Register (VGPR)](#vector-general-purpose-register-vgpr)
|
||||
- [Warp / Wavefront](#warp--wavefront)
|
||||
- [Wave Tile](#wave-tile)
|
||||
- [XDL Instructions](#xdl-instructions)
|
||||
|
||||
---
|
||||
|
||||
## 1. Hardware and Memory
|
||||
|
||||
### Processing Units
|
||||
The GPU is composed of multiple hardware units ([compute units (CUs)](#compute-unit-cu) on AMD, [streaming multiprocessors (SMs)](#compute-unit-cu) on NVIDIA), each containing many cores that run threads in parallel. These units manage shared resources and coordinate execution at scale.
|
||||
|
||||
### Matrix Core
|
||||
Specialized GPU units that accelerate matrix operations for AI and deep learning tasks. Modern GPUs contain multiple matrix cores.
|
||||
|
||||
### Compute Unit (CU)
|
||||
AMD's parallel vector processor in a GPU with multiple ALUs. Each compute unit will run all the waves in a workgroup. _This is equivalent to NVIDIA's streaming multiprocessor (SM)_.
|
||||
|
||||
### Matrix Fused Multiply-Add (MFMA)
|
||||
AMD's matrix core instruction for efficient GEMM operations. CK optimizes kernel designs to maximize MFMA utilization and performance.
|
||||
|
||||
### Registers
|
||||
The fastest memory tier, registers are private to each thread/work-item and used for storing temporary variables during computation. AMD distinguishes between [vector (VGPR)](#vector-general-purpose-register-vgpr) and [scalar (SGPR)](#scalar-general-purpose-register-sgpr) registers, while NVIDIA uses a unified register file.
|
||||
|
||||
### Vector General Purpose Register (VGPR)
|
||||
Per-thread registers that store individual thread data within a wave. Each thread has its own set of VGPRs for private variables and calculations.
|
||||
|
||||
### Scalar General Purpose Register (SGPR)
|
||||
Wave-level registers shared by all threads in a wave. Used for constants, addresses, and control flow common across the entire wave.
|
||||
|
||||
### Shared Memory / Local Data Share (LDS)
|
||||
AMD's high-bandwidth, low-latency on-chip memory accessible to all threads within a work group. This is equivalent to NVIDIA's shared memory. It enables fast data sharing and synchronization, but is limited in capacity and must be managed to avoid [bank conflicts](#bank-conflict).
|
||||
|
||||
### LDS Banks
|
||||
Memory organization where consecutive addresses are distributed across multiple memory banks for parallel access. Prevents memory access conflicts ([bank conflicts](#bank-conflict)) and improves bandwidth.
|
||||
|
||||
### Global Memory
|
||||
The main device memory accessible by all threads, offering high capacity but higher latency than shared memory.
|
||||
|
||||
### Pinned Memory
|
||||
Host memory that is page-locked to accelerate transfers between CPU and GPU, reducing overhead for large data movements.
|
||||
|
||||
### Dense Tensor
|
||||
A tensor in which most elements are nonzero, typically stored in a contiguous block of memory.
|
||||
|
||||
### Sparse Tensor
|
||||
A tensor in which most elements are zero, allowing for memory and computation optimizations by storing only nonzero values and their indices.
|
||||
|
||||
### Host
|
||||
CPU and main memory system that manages GPU execution. Launches kernels, transfers data, and coordinates overall computation.
|
||||
|
||||
### Device
|
||||
GPU hardware that executes parallel kernels. Contains compute units, memory hierarchy, and specialized accelerators.
|
||||
|
||||
---
|
||||
|
||||
## 2. GPU Programming Model
|
||||
|
||||
### Thread / Work-item
|
||||
AMD's work-item is the smallest unit of parallel execution, each running an independent instruction stream on a single data element. This is equivalent to NVIDIA's thread. Work-items/threads are grouped into [wavefronts (AMD)](#warp--wavefront) and [warps (NVIDIA)](#warp--wavefront) for efficient scheduling and resource sharing.
|
||||
|
||||
### Warp / Wavefront
|
||||
AMD's wavefront is a group of threads that run instructions in lockstep, forming the SIMD group. This is equivalent to NVIDIA's warp.
|
||||
|
||||
### Thread Block / Work Group
|
||||
AMD's work group is a collection of threads/work-items that can synchronize and share memory. This is equivalent to NVIDIA's thread block. Work groups/thread blocks are scheduled independently and mapped to hardware units for execution.
|
||||
|
||||
### Grid
|
||||
The complete collection of all work groups (thread blocks) that execute a kernel. A grid spans the entire computational domain and is organized in 1D, 2D, or 3D dimensions. Each work group within the grid operates independently and can be scheduled on different compute units, enabling massive parallel execution across the entire GPU.
|
||||
|
||||
### Block Size
|
||||
Number of work-items/threads in a compute unit (CU). Determines work group size and memory usage.
|
||||
|
||||
### Single-Instruction, Multi-Thread (SIMT) / Single-Instruction, Multi-Data (SIMD)
|
||||
SIMT (Single-Instruction, Multi-Thread) allows threads in a warp to diverge, while SIMD (Single-Instruction, Multi-Data) enforces strict lockstep execution within wavefronts. These models define how parallelism is expressed and managed on different architectures.
|
||||
|
||||
### Occupancy
|
||||
The ratio of active warps/wavefronts to the maximum number of warps/wavefronts supported by a hardware unit. Affects the ability to hide memory latency and maximize throughput.
|
||||
|
||||
---
|
||||
|
||||
## 3. Kernel Structure
|
||||
|
||||
### Kernel
|
||||
A function executed on the GPU, typically written in [HIP](#hip) or [CUDA](#cuda), that performs parallel computations over input data. Kernels are launched with specific grid and block dimensions to map computation to hardware. In CK, kernels are composed from pipelines and require a pipeline, tile partitioner, and epilogue component.
|
||||
|
||||
### Pipeline
|
||||
A CK Pipeline orchestrates the sequence of operations for a kernel, including data loading, computation, and storage phases. It consists of two core components: a [Problem](#problem) component that defines what to compute, and a [Policy](#policy) component that specifies how to move data around.
|
||||
|
||||
### Tile Partitioner
|
||||
Defines the mapping between problem dimensions (M, N, K) and GPU hierarchy. It specifies workgroup-level tile sizes (kM, kN, kK) and determines grid dimensions by dividing the problem size by tile sizes.
|
||||
|
||||
### Problem
|
||||
Defines what to compute - input/output shapes, data types, and mathematical operations (e.g., GEMM, convolution).
|
||||
|
||||
### Policy
|
||||
Defines memory access patterns and hardware-specific optimizations.
|
||||
|
||||
### User Customized Tile Pipeline
|
||||
User-defined pipeline that combines custom problem and policy components for specialized computations. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points.
|
||||
|
||||
### User Customized Tile Pipeline Optimization
|
||||
Process of tuning tile sizes, memory access patterns, and hardware utilization for specific workloads. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points.
|
||||
|
||||
### Tile Programming API
|
||||
CK's high-level interface for defining tile-based computations with predefined hardware mapping for data load/store.
|
||||
|
||||
### Coordinate Transformation Primitives
|
||||
CK utilities for converting between different coordinate systems (logical, physical, memory layouts).
|
||||
|
||||
### Reference Kernel
|
||||
A baseline kernel implementation used to verify correctness and performance. CK has two reference kernel implementations: one for CPU and one for GPU.
|
||||
|
||||
### Launch Parameters
|
||||
Configuration values (e.g., grid size, block size) that determine how a kernel is mapped to hardware resources. Proper tuning of these parameters is essential for optimal performance.
|
||||
|
||||
---
|
||||
|
||||
## 4. Memory Access and Data Layout
|
||||
|
||||
### Memory Coalescing
|
||||
An optimization where consecutive threads access consecutive memory addresses, allowing a single memory transaction to serve multiple threads. Proper coalescing is vital for achieving peak memory bandwidth.
|
||||
|
||||
### Alignment
|
||||
A memory management startegy for efficient memory access where data structures are stored at addresses that are multiples of a specific value.
|
||||
|
||||
### Bank Conflict
|
||||
Occurs when multiple threads in a warp/wavefront access different addresses mapping to the same shared memory bank, causing serialization and reduced bandwidth.
|
||||
|
||||
### Padding
|
||||
The addition of extra elements (often zeros) to tensor edges. This is used to control output size in convolution and pooling, or to align data for efficient memory access.
|
||||
|
||||
### Permute/Transpose
|
||||
Operations that rearrange the order of tensor axes, often required to match kernel input formats or optimize memory access patterns.
|
||||
|
||||
### Host-Device Transfer
|
||||
The process of moving data between CPU (host) and GPU (device) memory. Host-device transfers can be a performance bottleneck and are optimized using pinned memory and asynchronous operations.
|
||||
|
||||
### Stride
|
||||
The step size to move from one element to the next in a particular dimension of a tensor or matrix. In convolution and pooling, stride determines how far the kernel moves at each step.
|
||||
|
||||
### Dilation
|
||||
The spacing between kernel elements in convolution operations, allowing the receptive field to grow without increasing kernel size.
|
||||
|
||||
### Im2Col/Col2Im
|
||||
Data transformation techniques that convert image data to column format (im2col) for efficient convolution and back (col2im) to reconstruct the original layout.
|
||||
|
||||
### Fast Changing Dimension
|
||||
Innermost dimension that changes fastest in memory layout.
|
||||
|
||||
### Outer Dimension
|
||||
Slower-changing dimension in memory layout.
|
||||
|
||||
### Inner Dimension
|
||||
Faster-changing dimension in memory layout.
|
||||
|
||||
---
|
||||
|
||||
## 5. Tile-Based Computing and Data Structures
|
||||
|
||||
### Tile
|
||||
A sub-region of a tensor or matrix processed by a block or thread. Tiles are used to improve memory locality and enable blocking strategies in kernels. Rectangular data blocks are the unit of computation and memory transfer in CK and the basis for tiled algorithms.
|
||||
|
||||
### Block Tile
|
||||
Memory tile processed by a work group (thread block).
|
||||
|
||||
### Wave Tile
|
||||
Sub-tile processed by a single wave within a work group. Represents the granularity of SIMD execution.
|
||||
|
||||
### Tile Distribution
|
||||
Hierarchical data mapping from work-items to data in memory.
|
||||
|
||||
### Tile Window
|
||||
Viewport into a larger tensor that defines the current tile's position and boundaries for computation.
|
||||
|
||||
### Load Tile
|
||||
Operation that transfers data from global memory/LDS to per-thread registers using optimized memory access patterns.
|
||||
|
||||
### Store Tile
|
||||
Operation that transfers data from per-thread registers to LDS/global memory using optimized memory access patterns.
|
||||
|
||||
### Descriptor
|
||||
Metadata structure that defines tile properties, memory layouts, and coordinate transformations for CK operations.
|
||||
|
||||
### Input/Problem Shape
|
||||
Dimensions and data types of input tensors that define the computational problem (e.g., M×K, K×N for GEMM).
|
||||
|
||||
### Vector
|
||||
Smallest data unit processed by individual threads. Typically 4-16 elements depending on data type and hardware.
|
||||
|
||||
---
|
||||
|
||||
## 6. Kernel Operations and Optimization
|
||||
|
||||
### Elementwise
|
||||
Operations applied independently to each tensor element, such as addition or multiplication. These are highly parallelizable and benefit from efficient memory access.
|
||||
|
||||
### Epilogue
|
||||
The final stage of a kernel or operation, often applying activation functions, bias, or other post-processing steps. Epilogues are critical for integrating kernel outputs into larger computation graphs.
|
||||
|
||||
### Add+Multiply
|
||||
A common fused operation in ML and linear algebra, where an elementwise addition is immediately followed by multiplication, often used for bias and scaling in neural network layers.
|
||||
|
||||
---
|
||||
|
||||
## 7. Linear Algebra and ML Operations
|
||||
|
||||
### General Matrix Multiply (GEMM)
|
||||
Core matrix operation in linear algebra and deep learning. A GEMM is defined as C = αAB + βC for matrices A, B, and C.
|
||||
|
||||
### "Vanilla" GEMM (Naive GEMM) Kernel
|
||||
The **vanilla GEMM** is the simplest form of GEMM in CK. It:
|
||||
- Takes input matrices **A** and **B**
|
||||
- Multiplies them to produce output matrix **C**
|
||||
|
||||
This is the **baseline** or **building block** GEMM that all other complex versions expand upon.
|
||||
|
||||
### Grouped GEMM (GGEMMs)
|
||||
|
||||
A kernel which calls multiple VGEMMs. Each call can have a different input shape. Each input shape problem first finds its corresponding kernel and then data is mapped to the work-group (blocks) of that kernel.
|
||||
|
||||
### Batched GEMM
|
||||
A kernel which calls VGEMMs with different "batches" of data. All batches have the same input shape.
|
||||
|
||||
### Split-K GEMM
|
||||
A parallelization strategy that partitions the reduction dimension (K) across multiple compute units, increasing parallelism for large matrix multiplications.
|
||||
|
||||
### GEMV
|
||||
The operation of multiplying a matrix by a vector, producing another vector. GEMV (General Matrix Vector Multiplication) is a core linear algebra primitive, widely used in neural networks and scientific computing.
|
||||
|
||||
### Inner Product
|
||||
Also known as the dot product, it computes the sum of elementwise products of two vectors, yielding a scalar.
|
||||
|
||||
### Outer Product
|
||||
The result of multiplying a column vector by a row vector, producing a matrix. Outer products are used in rank-1 updates and some ML algorithms.
|
||||
|
||||
### Norm
|
||||
A function that measures the magnitude of a vector or matrix, such as L2 (Euclidean) or L1 norm. Norms are used in regularization, normalization, and optimization.
|
||||
|
||||
---
|
||||
|
||||
## 8. Testing, Build, and Infrastructure
|
||||
|
||||
### Regression Test
|
||||
Tests that are part of CK's ctest suite and explicitly take more than 30s to finish on gfx942.
|
||||
|
||||
### Smoke Test
|
||||
Tests that are part of CK's ctest suite and take less than or equal to 30 seconds to finish on gfx942.
|
||||
|
||||
---
|
||||
|
||||
## 9. Low-Level Instructions and Optimizations
|
||||
|
||||
### eXtensible Data Language (XDL) Instructions
|
||||
eXtensible Data Language (XDL) instructions are a set of specialized, low-level instructions used to optimize data movement, memory access, and layout in high-performance computing, GPU programming, and deep learning tasks.
|
||||
|
||||
---
|
||||
|
||||
## 10. Miscellaneous
|
||||
|
||||
### HIP
|
||||
AMD's Heterogeneous-Computing Interface for Portability, a C++ runtime API and programming language that enables developers to create portable applications for AMD and NVIDIA GPUs. HIP provides a familiar CUDA-like programming model while maintaining compatibility across different GPU architectures.
|
||||
|
||||
### CUDA
|
||||
NVIDIA's Compute Unified Device Architecture, a parallel computing platform and programming model for NVIDIA GPUs. CUDA provides a C++ extension for writing GPU kernels and managing GPU resources.
|
||||
|
||||
### ROCm
|
||||
AMD's Radeon Open Compute platform, an open-source software stack for GPU computing that includes [HIP](#hip), libraries, and tools for high-performance computing and machine learning workloads on AMD GPUs.
|
||||
|
||||
---
|
||||
|
||||
## Scientific Context and References
|
||||
|
||||
This terminology is grounded in parallel computing theory, numerical linear algebra, and computer architecture. For further reading, see:
|
||||
- [Building Efficient GEMM Kernels with CK Tile](https://rocm.blogs.amd.com/software-tools-optimization/building-efficient-gemm-kernels-with-ck-tile-vendo/README.html)
|
||||
- [CK Tile Flash](https://rocm.blogs.amd.com/software-tools-optimization/ck-tile-flash/README.html)
|
||||
|
||||
This document assumes familiarity with parallel computing, linear algebra, and computer architecture principles.
|
||||
|
||||
@@ -107,14 +107,14 @@ int execute_conv_fwd()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -130,14 +130,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -105,14 +105,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -109,14 +109,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -111,14 +111,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -59,7 +59,7 @@ int main()
|
||||
SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size);
|
||||
|
||||
std::array<const void*, 2> ab_input = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_dev_buf.GetDeviceBuffer()};
|
||||
b_dev_buf.GetDeviceBuffer()};
|
||||
std::vector<ck::index_t> abStride = {Stride, 1};
|
||||
std::array<std::vector<ck::index_t>, 2> abStrides = {abStride, abStride};
|
||||
|
||||
|
||||
@@ -68,15 +68,15 @@ int main(int argc, char* argv[])
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceReduce<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceAdd,
|
||||
PassThrough,
|
||||
UnaryDivide,
|
||||
PropagateNan,
|
||||
OutputIndex>;
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceAdd,
|
||||
PassThrough,
|
||||
UnaryDivide,
|
||||
PropagateNan,
|
||||
OutputIndex>;
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
|
||||
@@ -117,14 +117,14 @@ int execute_conv_bwd_data_bilinear()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{in.GetDeviceBuffer()},
|
||||
{in.GetDeviceBuffer()},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{in_lengths},
|
||||
{in_strides},
|
||||
{in_lengths},
|
||||
{in_strides},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -116,14 +116,14 @@ int execute_conv_bwd_data_scale()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -121,14 +121,14 @@ int execute_conv_fwd_bilinear()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{out.GetDeviceBuffer()},
|
||||
{out.GetDeviceBuffer()},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{out_lengths},
|
||||
{out_strides},
|
||||
{out_lengths},
|
||||
{out_strides},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -222,13 +222,13 @@ bool run_grouped_conv_fwd_convscale_reduce(
|
||||
ck::tensor_operation::element_wise::Scale{scale_wei},
|
||||
{}};
|
||||
auto conv_ok = ConvolutionScale<InDataType,
|
||||
WeiDataType,
|
||||
ConvOutDataType,
|
||||
ConvElementOp,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NumDimSpatial>(in,
|
||||
WeiDataType,
|
||||
ConvOutDataType,
|
||||
ConvElementOp,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NumDimSpatial>(in,
|
||||
wei,
|
||||
conv_out,
|
||||
elementwise_op,
|
||||
@@ -717,15 +717,15 @@ bool TensorFullReduction(SimpleDeviceMem& tensor,
|
||||
{
|
||||
std::cout << "\nReduction of spatial dimensions:" << std::endl;
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceReduce<OutDataType,
|
||||
OutDataType,
|
||||
OutDataType,
|
||||
NumDimSpatial,
|
||||
NumDimSpatial,
|
||||
ReduceOperation,
|
||||
PassThrough,
|
||||
AccElementwiseOperation,
|
||||
true, // PropagateNan
|
||||
false>; // OutputIndex
|
||||
OutDataType,
|
||||
OutDataType,
|
||||
NumDimSpatial,
|
||||
NumDimSpatial,
|
||||
ReduceOperation,
|
||||
PassThrough,
|
||||
AccElementwiseOperation,
|
||||
true, // PropagateNan
|
||||
false>; // OutputIndex
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
|
||||
@@ -120,14 +120,14 @@ int execute_conv_fwd_scale()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -129,8 +129,8 @@ int execute_conv_fwd_scaleadd_ab()
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -132,9 +132,9 @@ void PerformImageToColumnPad0(const ck::index_t G,
|
||||
ck::wrapper::size<0>(tile_shape));
|
||||
|
||||
const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global),
|
||||
decltype(output_tensor_global),
|
||||
decltype(tile_shape),
|
||||
decltype(thread_layout)>;
|
||||
decltype(output_tensor_global),
|
||||
decltype(tile_shape),
|
||||
decltype(thread_layout)>;
|
||||
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
|
||||
kernel,
|
||||
dim3(grid_size_x, grid_size_y, 1),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.15)
|
||||
project(ck_app)
|
||||
add_compile_options(-std=c++17)
|
||||
add_compile_options(-std=c++20)
|
||||
|
||||
if (DTYPES)
|
||||
add_definitions(-DDTYPES)
|
||||
|
||||
@@ -68,3 +68,6 @@ endif()
|
||||
|
||||
target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS})
|
||||
target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS})
|
||||
target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0)
|
||||
target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
|
||||
|
||||
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
|
||||
|
||||
add_compile_options(-std=c++17)
|
||||
add_compile_options(-std=c++20)
|
||||
|
||||
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
# TODO: Use object library
|
||||
|
||||
@@ -91,8 +91,9 @@ inline auto Transform(const Range& r, F f) -> std::vector<decltype(f(*r.begin())
|
||||
}
|
||||
|
||||
template <class Range1, class Range2, class F>
|
||||
inline auto Transform(const Range1& r1, const Range2& r2, F f)
|
||||
-> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
|
||||
inline auto Transform(const Range1& r1,
|
||||
const Range2& r2,
|
||||
F f) -> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
|
||||
{
|
||||
std::vector<decltype(f(*r1.begin(), *r2.begin()))> result;
|
||||
assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end()));
|
||||
|
||||
@@ -142,12 +142,11 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
|
||||
x.A = TensorDesc{prob.ADataType, prob.ALayout};
|
||||
x.B = TensorDesc{prob.BDataType, prob.BLayout};
|
||||
x.E = TensorDesc{prob.EDataType, prob.ELayout};
|
||||
x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) {
|
||||
return TensorDesc{dt, lo};
|
||||
});
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.Ds = Transform(
|
||||
prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { return TensorDesc{dt, lo}; });
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
result.push_back(x);
|
||||
|
||||
@@ -55,12 +55,12 @@ TEST_CASE(test_problem_kernel)
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)},
|
||||
{"o", std::to_string(prob.O)}});
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)},
|
||||
{"o", std::to_string(prob.O)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
|
||||
@@ -60,11 +60,11 @@ TEST_CASE(test_problem_kernel)
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
|
||||
@@ -16,7 +16,7 @@ struct tmp_dir
|
||||
|
||||
void execute(const std::string& cmd) const;
|
||||
|
||||
tmp_dir(tmp_dir const&) = delete;
|
||||
tmp_dir(tmp_dir const&) = delete;
|
||||
tmp_dir& operator=(tmp_dir const&) = delete;
|
||||
|
||||
~tmp_dir();
|
||||
|
||||
@@ -94,7 +94,7 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o
|
||||
assert(not srcs.empty());
|
||||
tmp_dir td{"compile"};
|
||||
options.flags += " -I. -O3";
|
||||
options.flags += " -std=c++17";
|
||||
options.flags += " -std=c++20";
|
||||
options.flags += " --offload-arch=" + get_device_name();
|
||||
std::string out;
|
||||
|
||||
@@ -278,7 +278,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(const std::vector<src
|
||||
static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_options options)
|
||||
{
|
||||
options.flags += " -I. -O3";
|
||||
options.flags += " -std=c++17";
|
||||
options.flags += " -std=c++20";
|
||||
options.flags += " -DCK_CODE_GEN_RTC";
|
||||
options.flags += " --offload-arch=" + get_device_name();
|
||||
auto cos = compile_hip_src_with_hiprtc(srcs, options);
|
||||
|
||||
@@ -19,7 +19,6 @@ Getting started
|
||||
build the library. You can also find some of this information in the
|
||||
`README file <https://github.com/ROCm/composable_kernel/blob/develop/README.md>`_
|
||||
on the project's GitHub page.
|
||||
#. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_ provides a deeper understanding of the CK library and showcases its performance capabilities.
|
||||
<https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_
|
||||
from the AMD Community portal. It offers a deeper understanding of the library's objectives and showcases its performance capabilities.
|
||||
#. **General information:** For broader information about AMD products, consider exploring the
|
||||
|
||||
@@ -39,6 +39,7 @@ The Composable Kernel repository is located at `https://github.com/ROCm/composab
|
||||
* :doc:`Composable Kernel API reference <./doxygen/html/namespace_c_k>`
|
||||
* :doc:`CK Tile API reference <./doxygen/html/namespaceck__tile>`
|
||||
* :doc:`Composable Kernel complete API class list <./doxygen/html/annotated>`
|
||||
* :doc:`Composable Kernel glossary <./reference/Composable-Kernel-Glossary>`
|
||||
|
||||
To contribute to the documentation refer to `Contributing to ROCm <https://rocm.docs.amd.com/en/latest/contribute/contributing.html>`_.
|
||||
|
||||
|
||||
@@ -29,4 +29,4 @@ The following prerequisites are required to build and install Composable Kernel:
|
||||
* zlib1g-dev
|
||||
* libzstd-dev
|
||||
* openssh-server
|
||||
* clang-format-12
|
||||
* clang-format-18
|
||||
|
||||
256
docs/reference/Composable-Kernel-Glossary.rst
Normal file
256
docs/reference/Composable-Kernel-Glossary.rst
Normal file
@@ -0,0 +1,256 @@
|
||||
.. meta::
|
||||
:description: Composable Kernel glossary of terms
|
||||
:keywords: composable kernel, glossary
|
||||
|
||||
***************************************************
|
||||
Composable Kernel glossary
|
||||
|
||||
***************************************************
|
||||
|
||||
.. glossary::
|
||||
:sorted:
|
||||
|
||||
arithmetic logic unit
|
||||
The arithmetic logic unit (ALU) is the GPU component responsible for arithmetic and logic operations.
|
||||
|
||||
compute unit
|
||||
The compute unit (CU) is the parallel vector processor in an AMD GPU with multiple :term:`ALUs<arithmetic logic unit>`. Each compute unit will run all the :term:`wavefronts<wavefront>` in a :term:`work group>`. A compute unit is equivalent to NVIDIA's streaming multiprocessor.
|
||||
|
||||
matrix core
|
||||
A matrix core is a specialized GPU unit that accelerate matrix operations for AI and deep learning tasks. A GPU contains multiple matrix cores.
|
||||
|
||||
register
|
||||
Registers are the fastest tier of memory. They're used for storing temporary values during computations and are private to the :term:`work-items<work-item>` that use them.
|
||||
|
||||
VGPR
|
||||
See :term:`vector general purpose register`.
|
||||
|
||||
vector general purpose register
|
||||
A vector general purpose register (VGPR) is a :term:`register` that stores individual thread data. Each thread in a :term:`wave<wavefront>` has its own set of VGPRs for private variables and calculations.
|
||||
|
||||
SGPR
|
||||
See :term:`scalar general purpose register`.
|
||||
|
||||
scalar general purpose register
|
||||
A scalar general purpose register (SGPR) is a :term:`register` shared by all the :term:`work items<work item>` in a :term:`wave<wavefront>`. SGPRs are used for constants, addresses, and control flow common across the entire wave.
|
||||
|
||||
LDS
|
||||
See :term:`local data share`.
|
||||
|
||||
local data share
|
||||
Local data share (LDS) is high-bandwidth, low-latency on-chip memory accessible to all the :term:`work-items<work-item>` in a :term:`work group`. LDS is equivalent to NVIDIA's shared memory.
|
||||
|
||||
LDS banks
|
||||
LDS banks are a type of memory organization where consecutive addresses are distributed across multiple memory banks for parallel access. LDS banks are used to prevent memory access conflicts and improve bandwidth when LDS is used.
|
||||
|
||||
global memory
|
||||
The main device memory accessible by all threads, offering high capacity but higher latency than shared memory.
|
||||
|
||||
pinned memory
|
||||
Pinned memory is :term:`host` memory that is page-locked to accelerate transfers between the CPU and GPU.
|
||||
|
||||
dense tensor
|
||||
A dense tensor is a tensor where most of its elements are non-zero. Dense tensors are typically stored in a contiguous block of memory.
|
||||
|
||||
sparse tensor
|
||||
A sparse tensor is a tensor where most of its elements are zero. Typically only the non-zero elements of a sparse tensor and their indices are stored.
|
||||
|
||||
host
|
||||
Host refers to the CPU and the main memory system that manages GPU execution. The host is responsible for launching kernels, transferring data, and coordinating overall computation.
|
||||
|
||||
device
|
||||
Device refers to the GPU hardware that runs parallel kernels. The device contains the :term:`compute units<compute unit>`, memory hierarchy, and specialized accelerators.
|
||||
|
||||
work-item
|
||||
A work-item is the smallest unit of parallel execution. A work-item runs a single independent instruction stream on a single data element. A work-item is equivalent to an NVIDIA thread.
|
||||
|
||||
wavefront
|
||||
Also referred to as a wave, a wavefront is a group of :term:`work-items<work-item>` that run the same instruction. A wavefront is equivalent to an NVIDIA warp.
|
||||
|
||||
work group
|
||||
A work group is a collection of :term:`work-items<work-item>` that can synchronize and share memory. A work group is equivalent to NVIDIA's thread block.
|
||||
|
||||
grid
|
||||
A grid is a collection of :term:`work groups<work group>` that run a kernel. Each work group within the grid operates independently and can be scheduled on a different :term:`compute unit`. A grid can be organized into one, two, or three dimensions. A grid is equivalent to an NVIDIA thread block.
|
||||
|
||||
block Size
|
||||
The block size is the number of :term:`work-items<work-item>` in a :term:`compute unit`.
|
||||
|
||||
SIMT
|
||||
See :term:`single-instruction, multi-thread`
|
||||
|
||||
single-instruction, multi-thread
|
||||
Single-instruction, multi-thread (SIMT) is a parallel computing model where all the :term:`work-items<work-item>` within a :term:`wavefront` run the same instruction on different data.
|
||||
|
||||
SIMD
|
||||
See :term:`single-instruction, multi-data`
|
||||
|
||||
single-instruction, multi-data
|
||||
Single-instruction, multi-data (SIMD) is a parallel computing model where the same instruction is run with different data simultaneously.
|
||||
|
||||
occupancy
|
||||
The ratio of active :term:`wavefronts<wavefront>` to the maximum possible number of wavefronts.
|
||||
|
||||
kernel
|
||||
A kernel is a function that runs an :term:`operation` or a collection of operations. A kernel will run in parallel on several :term:`work-items<work-item>` across the GPU. In Composable Kernel, kernels require :term:`pipelines<pipeline>`.
|
||||
|
||||
operation
|
||||
An operation is a computation on input data.
|
||||
|
||||
pipeline
|
||||
A Composable Kernel pipeline schedules the sequence of operations for a :term:`kernel`, such as the data loading, computation, and storage phases. A pipeline consists of a :term:`problem` and a :term:`policy`.
|
||||
|
||||
tile partitioner
|
||||
The tile partitioner defines the mapping between the :term:`problem` dimensions and GPU hierarchy. It specifies :term:`workgroup`-level :term:`tile` sizes and determines :term:`grid` dimensions by dividing the problem size by the tile sizes.
|
||||
|
||||
problem
|
||||
The problem is the part of the :term:`pipeline` that defines input and output shapes, data types, and mathematical :term:`operations<operation>`.
|
||||
|
||||
policy
|
||||
The policy is the part of the :term:`pipeline` that defines memory access patterns and hardware-specific optimizations.
|
||||
|
||||
user customized tile pipeline
|
||||
A customized :term:`tile` :term:`pipeline` that combines custom :term:`problem` and :term:`policy` components for specialized computations.
|
||||
|
||||
user customized tile pipeline optimization
|
||||
The process of tuning the :term:`tile` size, memory access pattern, and hardware utilization for specific workloads.
|
||||
|
||||
tile programming API
|
||||
The :term:`tile` programming API is Composable Kernel's high-level interface for defining tile-based computations with predefined hardware mappings for data loading and storing.
|
||||
|
||||
coordinate transformation primitives
|
||||
Coordinate transformation primitives are Composable Kernel utilities for converting between different coordinate systems.
|
||||
|
||||
reference kernel
|
||||
A reference :term:`kernel` is a baseline kernel implementation used to verify correctness and performance. Composable Kernel makes two reference kernels, one for CPU and one for GPU, available.
|
||||
|
||||
launch parameters
|
||||
Launch parameters are the configuration values, such as :term:`grid` and :term:`block size`, that determine how a :term:`kernel` is mapped to hardware resources.
|
||||
|
||||
memory coalescing
|
||||
Memory coalescing is an optimization strategy where consecutive :term:`work-items<work-item>` access consecutive memory addresses in such a way that a single memory transaction serves multiple work-items.
|
||||
|
||||
alignment
|
||||
Alignment is a memory management strategy where data structures are stored at addresses that are multiples of a specific value.
|
||||
|
||||
|
||||
bank conflict
|
||||
A bank conflict occurs when multiple :term:`work-items<work-item>` in a :term:`wavefront` access different addresses that map to the same shared memory bank.
|
||||
|
||||
padding
|
||||
Padding is the addition of extra elements, often zeros, to tensor edges in order to control output size in convolution and pooling, or to align data for memory access.
|
||||
|
||||
transpose
|
||||
Transpose is an :term:`operation` that rearranges the order of tensor axes, often for the purposes of matching :term:`kernel` input formats or optimize memory access patterns.
|
||||
|
||||
permute
|
||||
Permute is an :term:`operation` that rearranges the order of tensor axes, often for the purposes of matching :term:`kernel` input formats or optimize memory access patterns.
|
||||
|
||||
host-device transfer
|
||||
A host-device transfer is the process of moving data between :term:`host` and :term:`device` memory.
|
||||
|
||||
stride
|
||||
A stride is the step size to move from one element to the next in a specific dimension of a tensor or matrix. In convolution and pooling, the stride determines how far the :term:`kernel` moves at each step.
|
||||
|
||||
dilation
|
||||
Dilation is the spacing between :term:`kernel` elements in convolution :term:`operations<operation>`, allowing the receptive field to grow without increasing kernel size.
|
||||
|
||||
Im2Col
|
||||
Im2Col is a data transformation technique that converts image data to column format.
|
||||
|
||||
Col2Im
|
||||
Col2Im is a data transformation technique that converts column data to image format.
|
||||
|
||||
fast changing dimension
|
||||
The fast changing dimension is the innermost dimension in memory layout.
|
||||
|
||||
outer dimension
|
||||
The outer dimension is the slower-changing dimension in memory layout.
|
||||
|
||||
inner dimension
|
||||
The inner dimension is the faster-changing dimension in memory layout.
|
||||
|
||||
tile
|
||||
A tile is a sub-region of a tensor or matrix that is processed by a :term:`work group` or :term:`work-item`. Rectangular data blocks are the unit of computation and memory transfer in Composable Kernel, and are the basis for tiled algorithms.
|
||||
|
||||
block tile
|
||||
A block tile is a memory :term:`tile` processed by a :term:`work group`.
|
||||
|
||||
wave tile
|
||||
A wave :term:`tile` is a sub-tile processed by a single :term:`wavefront` within a :term:`work group`. The wave tile is the base level granularity of a :term:`single-instruction, multi-thread (SIMD)<single-instruction, multi-thread>` model.
|
||||
|
||||
tile distribution
|
||||
The tile distribution is the hierarchical data mapping from :term:`work-items<work-item>` to data in memory.
|
||||
|
||||
tile window
|
||||
Viewport into a larger tensor that defines the current tile's position and boundaries for computation.
|
||||
|
||||
load tile
|
||||
Load tile is an operation that transfers data from :term:`global memory` or the :term:`load data share` to :term:`vector general purpose registers<vector general purpose register>`.
|
||||
|
||||
store tile
|
||||
Store tile is an operation that transfers data from :term:`vector general purpose registers<vector general purpose register>` to :term:`global memory` or the :term:`load data share`.
|
||||
|
||||
descriptor
|
||||
Metadata structure that defines :term:`tile` properties, memory layouts, and coordinate transformations for Composable Kernel :term:`operations<operation>`.
|
||||
|
||||
input
|
||||
See :term:`problem shape`.
|
||||
|
||||
problem shape
|
||||
The problem shape defines the dimensions and data types of input tensors that define the :term:`problem`.
|
||||
|
||||
vector
|
||||
The vector is the smallest data unit processed by an individual :term:`work-item`. A vectors is typically four to sixteen elements, depending on data type and hardware.
|
||||
|
||||
elementwise
|
||||
An elementwise :term:`operation` is an operation applied to each tensor element independently.
|
||||
|
||||
epilogue
|
||||
The epilogue is the final stage of a kernel. Activation functions, bias, and other post-processing steps are applied in the epilogue.
|
||||
|
||||
Add+Multiply
|
||||
See :term:`fused add multiply`.
|
||||
|
||||
fused add multiply
|
||||
A common fused :term:`operation` in machine language and linear algebra, where an :term:`elementwise` addition is immediately followed by a multiplication. Fused add multiply is often used for bias and scaling in neural network layers.
|
||||
|
||||
MFMA
|
||||
See :term:`matrix fused multiply-add`.
|
||||
|
||||
matrix fused multiply-add
|
||||
Matrix fused multiply-add (MFMA) is a :term:`matrix core` instruction for GEMM :term:`operations<operation>`.
|
||||
|
||||
GEMM
|
||||
See :term:`general matrix multiply`.
|
||||
|
||||
general matrix multiply
|
||||
A general matrix multiply (GEMM) is a Core matrix :term:`operation` in linear algebra and deep learning. A GEMM is defined as :math:`C = {\alpha}AB + {\beta}C`, where :math:`A`, :math:`B`, and :math:`C` are matrices, and :math:`\alpha` and :math:`\beta` are scalars.
|
||||
|
||||
VGEMM
|
||||
See :term:`naive GEMM`.
|
||||
|
||||
vanilla GEMM
|
||||
See :term:`naive GEMM`.
|
||||
|
||||
naive GEMM
|
||||
The naive GEMM, sometimes referred to as a vanilla GEMM or VGEMM, is the simplest form of :term:`GEMM` in Composable Kernel. The naive GEMM is defined as :math:`C = AB`, where :math:`A`, :math:`B`, and :math:`C` are matrices. The naive GEMM is the baseline GEMM that all other GEMM :term:`operations<operation>` build on.
|
||||
|
||||
GGEMM
|
||||
See :term:`grouped GEMM`.
|
||||
|
||||
grouped GEMM
|
||||
A :term:`kernel` that calls multiple :term:`VGEMMs<naive GEMM>`. Each call can have a different :term:`problem shape`.
|
||||
|
||||
batched GEMM
|
||||
A :term:`kernel` that calls :term:`VGEMMs<naive GEMM>` with different batches of data. All the data batches have the same :term:`problem shape`.
|
||||
|
||||
Split-K GEMM
|
||||
Split-K GEMM is a parallelization strategy that partitions the reduction dimension (K) of a :term:`GEMM` across multiple :term:`compute units<compute unit>`, increasing parallelism for large matrix multiplications.
|
||||
|
||||
GEMV
|
||||
See :term:`general matrix vector multiplication`
|
||||
|
||||
general matrix vector multiplication
|
||||
General matrix vector multiplication (GEMV) is an :term:`operation` where a matrix is multiplied by a vector, producing another vector.
|
||||
|
||||
@@ -34,8 +34,14 @@ subtrees:
|
||||
title: Composable Kernel vector utilities
|
||||
- file: reference/Composable-Kernel-wrapper.rst
|
||||
title: Composable Kernel wrapper
|
||||
- file: doxygen/html/namespace_c_k.rst
|
||||
title: CK API reference
|
||||
- file: doxygen/html/namespaceck__tile.rst
|
||||
title: CK Tile API reference
|
||||
- file: doxygen/html/annotated.rst
|
||||
title: Composable Kernel class list
|
||||
title: Full API class list
|
||||
- file: reference/Composable-Kernel-Glossary.rst
|
||||
title: Glossary
|
||||
|
||||
- caption: About
|
||||
entries:
|
||||
|
||||
@@ -128,3 +128,5 @@ add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.c
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale)
|
||||
|
||||
367
example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp
Normal file
367
example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp
Normal file
@@ -0,0 +1,367 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::pk_i4_t;
|
||||
using BScaleDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = true;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_N = 1;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t KPerBlock = 64;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_BScale_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256, Scale_Block_N, Scale_Block_K,
|
||||
128, 128,
|
||||
KPerBlock, 8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
|
||||
CDataType, CDataType, PermuteA, PermuteB>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N,
|
||||
Scale_Stride_BN,
|
||||
BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 4:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 5:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
|
||||
DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
// weight permute
|
||||
if constexpr(PermuteB)
|
||||
{
|
||||
int K1 = KPerBlock;
|
||||
int K0 = K / KPerBlock;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j++)
|
||||
{
|
||||
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int i4x2 = b_k_n_permute(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
|
||||
b1_scale_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmV2Instance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
auto argument =
|
||||
gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
Scale_Stride_BN,
|
||||
static_cast<BScaleDataType*>(b1_scale_device_buf.GetDeviceBuffer()),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string device_name = ck::get_device_name();
|
||||
if(!(device_name.find("gfx11") != std::string::npos ||
|
||||
device_name.find("gfx12") != std::string::npos))
|
||||
{
|
||||
std::cout << "This kernel support gfx1100 and gfx1200 only" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
Tensor<float> b_k_n_dequant({K, N});
|
||||
|
||||
float v_b = 0;
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
ck::pk_i4_t i4x2 = b_k_n(k, n).data;
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
v_b = ck::type_convert<float>(i4);
|
||||
|
||||
b_k_n_dequant(k, n) =
|
||||
ck::type_convert<float>(v_b) *
|
||||
ck::type_convert<float>(b1_k_n(k / Scale_Block_K, n / Scale_Block_N));
|
||||
}
|
||||
}
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_splitk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
@@ -31,15 +31,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
|
||||
#else
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
|
||||
#endif
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -56,10 +56,10 @@ using CDataType = float;
|
||||
using AccDataType = float;
|
||||
|
||||
#endif
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, float, AElementOp, BElementOp, CElementOp>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, float, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
template <typename DataType>
|
||||
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
@@ -53,6 +54,17 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
try
|
||||
{
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
@@ -117,7 +117,7 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
using InOutDataTypeInDevice = typename std::
|
||||
conditional<std::is_same<InOutDataType, int4_t>::value, int8_t, InOutDataType>::type;
|
||||
#else
|
||||
using InOutDataTypeInDevice = InOutDataType;
|
||||
using InOutDataTypeInDevice = InOutDataType;
|
||||
#endif
|
||||
|
||||
using DeviceReduceInstance =
|
||||
|
||||
@@ -175,15 +175,15 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
{r0_device_buf.GetDeviceBuffer()},
|
||||
{r0_device_buf.GetDeviceBuffer()},
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
{},
|
||||
{},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
|
||||
@@ -207,7 +207,7 @@ int main(int argc, char* argv[])
|
||||
auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
{},
|
||||
{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
p_reduces,
|
||||
M,
|
||||
@@ -216,9 +216,9 @@ int main(int argc, char* argv[])
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
{},
|
||||
{},
|
||||
gemm_element_ops,
|
||||
{},
|
||||
{},
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
BatchCount);
|
||||
|
||||
@@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example()
|
||||
{0, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1},
|
||||
1e-4,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
|
||||
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_bf16 batched_gemm_gemm_wmma_cshuffle_v3_bf16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp8 batched_gemm_gemm_wmma_cshuffle_v3_fp8.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp16 batched_gemm_gemm_wmma_cshuffle_v3_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_int8 batched_gemm_gemm_wmma_cshuffle_v3_int8.cpp)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
@@ -0,0 +1,276 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/*
|
||||
Gemm + Gemm fused operation. Computes C_g_m_n = (A_g_m_k * B0_g_k_l) * B1_g_l_n
|
||||
|------------------|
|
||||
Gemm0
|
||||
|-----------------------------|
|
||||
Gemm1
|
||||
*/
|
||||
|
||||
static constexpr auto PipeSched = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// clang-format off
|
||||
// #define CK_MHA_USE_RCCR_LAYOUT
|
||||
#define CK_MHA_USE_WAVE_1
|
||||
// #define CK_MHA_USE_WAVE_2
|
||||
// #define CK_MHA_USE_WAVE_4
|
||||
// #define CK_MHA_USE_WAVE_8
|
||||
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Col, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
// Gemm 0
|
||||
16, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 16, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
>;
|
||||
#else
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
#ifdef CK_MHA_USE_WAVE_1
|
||||
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
// Gemm 0
|
||||
16, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 16, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
// Gemm 0
|
||||
16, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 16, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_2
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
64,
|
||||
// Gemm 0
|
||||
32, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 32, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
64,
|
||||
// Gemm 0
|
||||
32, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 32, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_4
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
128,
|
||||
// Gemm 0
|
||||
64, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
128,
|
||||
// Gemm 0
|
||||
64, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_8
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
256,
|
||||
// Gemm 0
|
||||
128, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 128, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
256,
|
||||
// Gemm 0
|
||||
128, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 128, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
>;
|
||||
#endif
|
||||
|
||||
// clang-format on
|
||||
// Ref Gemm0
|
||||
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B0DataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp>;
|
||||
|
||||
// Ref Gemm1
|
||||
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported() || ck::is_gfx12_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run(argc, argv);
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = BF16;
|
||||
using B0DataType = BF16;
|
||||
using B1DataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = BF16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,37 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using B0DataType = F16;
|
||||
using B1DataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = F16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,34 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using B0DataType = ck::f8_t;
|
||||
using B1DataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = ck::f8_t;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,34 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = int8_t;
|
||||
using B0DataType = int8_t;
|
||||
using B1DataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int32_t;
|
||||
using CDataType = int8_t;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,304 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape for A/B0/B1/C
|
||||
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
|
||||
ck::index_t M = 113;
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
ck::index_t N = 480; // Must be multiple of 8 even with padding.
|
||||
#else
|
||||
ck::index_t N = 477;
|
||||
#endif
|
||||
ck::index_t K = 200; // Must be multiple of 8 even with padding.
|
||||
ck::index_t O = 208; // Must be multiple of 8 even with padding.
|
||||
ck::index_t G = 91; // Batch
|
||||
|
||||
float alpha = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
G = std::stoi(argv[8]);
|
||||
|
||||
alpha = std::stof(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 8: M, N, K, O, G\n");
|
||||
printf("arg9: scale (alpha)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> a_g_m_k_lengths{G, M, K};
|
||||
std::vector<ck::index_t> a_g_m_k_strides{M * K, K, 1}; // A layout [G, M, K]
|
||||
std::vector<ck::index_t> b0_g_n_k_lengths{G, N, K};
|
||||
std::vector<ck::index_t> b0_g_n_k_strides{N * K, K, 1}; // B0 layout [G, N, K]
|
||||
std::vector<ck::index_t> b1_g_o_n_lengths{G, O, N};
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
std::vector<ck::index_t> b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N]
|
||||
#else
|
||||
std::vector<ck::index_t> b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O]
|
||||
#endif
|
||||
std::vector<ck::index_t> c_g_m_o_lengths{G, M, O};
|
||||
std::vector<ck::index_t> c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O]
|
||||
|
||||
Tensor<ADataType> a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides);
|
||||
Tensor<B0DataType> b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides);
|
||||
Tensor<B1DataType> b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
Tensor<CDataType> c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
Tensor<CDataType> c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_g_o_n: " << b1_g_o_n.mDesc << std::endl;
|
||||
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
break;
|
||||
case 3:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
break;
|
||||
case 4: // A, B0, B1 1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 5: // Rand: b1 b0; unit: a
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 6: // Rand: a b0 ; unit: B1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 7: // Rand: a b1 ; unit: b0
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 8: // Rand: a ; unit: b0 b1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 9: // Rand: b0 ; unit: a b1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 10: // Rand: b1 ; unit: a b0
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_g_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_g_o_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_g_n_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_g_o_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{alpha};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
float best_perf = .0;
|
||||
float best_time = .0;
|
||||
int not_pass = 0;
|
||||
std::string best_kernel = "";
|
||||
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
|
||||
const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
|
||||
|
||||
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
|
||||
auto gemm = DeviceMHAInstance{};
|
||||
auto invoker_ptr = gemm.MakeInvokerPointer();
|
||||
auto argument_ptr =
|
||||
gemm.MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
G, // Batch,
|
||||
a_g_m_k_strides[1], // StrideA,
|
||||
b0_g_n_k_strides[1], // StrideB0,
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
b1_g_o_n_strides[1], // StrideB1,
|
||||
#else
|
||||
b1_g_o_n_strides[2], // StrideB1,
|
||||
#endif
|
||||
c_g_m_o_strides[1], // StrideC,
|
||||
a_g_m_k_strides[0], // BatchStrideA
|
||||
b0_g_n_k_strides[0], // BatchStrideB0
|
||||
b1_g_o_n_strides[0], // BatchStrideB1
|
||||
c_g_m_o_strides[0], // BatchStrideC
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc0_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G;
|
||||
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
|
||||
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
|
||||
G;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
if(tflops > best_perf)
|
||||
{
|
||||
best_perf = tflops;
|
||||
best_time = ave_time * 1000;
|
||||
best_kernel = gemm.GetTypeString();
|
||||
}
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
|
||||
|
||||
Tensor<B0DataType> b0_g_k_n({G, K, N});
|
||||
Tensor<B1DataType> b1_g_n_o({G, N, O});
|
||||
Tensor<AccDataType> acc0_g_m_n({G, M, N}); // scratch object after gemm0
|
||||
Tensor<ADataType> a1_g_m_n({G, M, N}); // scratch object after conversion
|
||||
|
||||
// permute
|
||||
b0_g_n_k.ForEach(
|
||||
[&](auto& self, auto idx) { b0_g_k_n(idx[0], idx[2], idx[1]) = self(idx); });
|
||||
b1_g_o_n.ForEach(
|
||||
[&](auto& self, auto idx) { b1_g_n_o(idx[0], idx[2], idx[1]) = self(idx); });
|
||||
|
||||
// gemm 0
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
|
||||
// Passthrough instead of softmax, DOES involve data type conversion.
|
||||
a1_g_m_n(idx) = ck::type_convert<ADataType, AccDataType>(self(idx));
|
||||
});
|
||||
|
||||
// gemm1
|
||||
auto ref_gemm1 = ReferenceGemm1Instance{};
|
||||
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
|
||||
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
|
||||
b1_g_n_o,
|
||||
c_g_m_o_host_result,
|
||||
PassThrough{},
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_gemm1_invoker.Run(ref_gemm1_argument);
|
||||
|
||||
// default absolute error and relative error is 0.001
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
|
||||
// when BF16 is taken, set absolute error and relative error to 0.01
|
||||
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
|
||||
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
rtol = 1e-2;
|
||||
atol = 1e-2;
|
||||
}
|
||||
|
||||
bool this_run_verification = ck::utils::check_err(c_g_m_o_device_result.mData,
|
||||
c_g_m_o_host_result.mData,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
printf("Verification: %s, Pass: %s\n",
|
||||
do_verification ? "ON" : "OFF",
|
||||
this_run_verification ? "YES" : "NO");
|
||||
|
||||
if(!this_run_verification)
|
||||
{
|
||||
not_pass = 1;
|
||||
printf("%d th MHA instance verification Failed \n", i.value);
|
||||
}
|
||||
}
|
||||
});
|
||||
std::cout << "---------------------------------------------------------------------------------"
|
||||
"-----------"
|
||||
<< std::endl;
|
||||
std::cout << "Problem Size: G: " << G << ", M: " << M << ", N: " << N << ", K: " << K
|
||||
<< ", O: " << O << std::endl;
|
||||
std::cout << "---------------------------------------------------------------------------------"
|
||||
"-----------"
|
||||
<< std::endl;
|
||||
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
|
||||
<< " us" << std::endl;
|
||||
std::cout << "---------------------------------------------------------------------------------"
|
||||
"-----------"
|
||||
<< std::endl;
|
||||
return not_pass;
|
||||
}
|
||||
@@ -126,10 +126,10 @@ int run(int argc, char* argv[])
|
||||
|
||||
if(i < 4)
|
||||
{
|
||||
std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", "
|
||||
<< "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
|
||||
<< "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
|
||||
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
|
||||
std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " << "b0_gs_ns_ks["
|
||||
<< i << "]: " << b0_gs_ns_ks.mDesc << ", " << "b1_gs_os_ns[" << i
|
||||
<< "]: " << b1_gs_os_ns.mDesc << ", " << "c_gs_ms_os[" << i
|
||||
<< "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
|
||||
}
|
||||
|
||||
switch(init_method)
|
||||
|
||||
@@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
|
||||
@@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -129,11 +129,11 @@ int main()
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(
|
||||
out_dev.GetDeviceBuffer(),
|
||||
{ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
|
||||
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
|
||||
{ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
|
||||
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
|
||||
gamma_dev.GetDeviceBuffer(),
|
||||
beta_dev.GetDeviceBuffer(),
|
||||
current_dim,
|
||||
|
||||
@@ -92,7 +92,7 @@ inline bool parse_cmd_args(int argc,
|
||||
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
conv_params = ck::utils::conv::parse_conv_param(
|
||||
num_dim_spatial, threshold_to_catch_partial_args, argv);
|
||||
num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -249,8 +249,8 @@ inline auto to_array(Range& range) noexcept
|
||||
}
|
||||
|
||||
template <typename Axes>
|
||||
inline auto is_valid_axes(const Axes& axes)
|
||||
-> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
|
||||
inline auto
|
||||
is_valid_axes(const Axes& axes) -> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
|
||||
{
|
||||
using std::empty;
|
||||
if(empty(axes))
|
||||
@@ -357,10 +357,11 @@ auto extend_axes(const Problem::Axes& axes)
|
||||
}
|
||||
|
||||
template <typename Shape, typename Indices>
|
||||
auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t<
|
||||
detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
|
||||
detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
|
||||
bool>
|
||||
auto advance_indices(const Shape& shape, Indices& indices)
|
||||
-> std::enable_if_t<
|
||||
detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
|
||||
detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
|
||||
bool>
|
||||
{
|
||||
using std::size;
|
||||
if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices)))
|
||||
|
||||
@@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[])
|
||||
{0, 0, 0, C, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1, 2, 4}, // reduction dimension: [H, W, C]
|
||||
1e-6,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -152,7 +152,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::array<const void*, 1> inputs = {input_dev_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 2> outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(),
|
||||
output_scaled_casted_dev_buf.GetDeviceBuffer()};
|
||||
output_scaled_casted_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::cout << "Input: " << input.mDesc << std::endl;
|
||||
std::cout << "Scale: " << scale << std::endl;
|
||||
@@ -164,8 +164,8 @@ int main(int argc, char* argv[])
|
||||
auto launch_transpose_scale = [&]() {
|
||||
auto transposeScale = DeviceElementwisePermuteInstance{};
|
||||
auto argument = transposeScale.MakeArgumentPointer(dims,
|
||||
{in_strides},
|
||||
{out_strides, in_strides},
|
||||
{in_strides},
|
||||
{out_strides, in_strides},
|
||||
inputs,
|
||||
outputs,
|
||||
ScalePassThrough{scale});
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -213,7 +213,7 @@ int main(int argc, char* argv[])
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
|
||||
@@ -194,9 +194,9 @@ int main(int argc, char* argv[])
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 2>{b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer()},
|
||||
b1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 0>{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
std::array<std::vector<ck::index_t>, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths},
|
||||
|
||||
@@ -265,10 +265,10 @@ bool run_grouped_conv_fwd(bool do_verification,
|
||||
auto device_ew_scale = DeviceElementwiseScale{};
|
||||
auto scale_invoker = device_ew_scale.MakeInvoker();
|
||||
auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths,
|
||||
{e_g_n_k_wos_strides},
|
||||
{e_g_n_k_wos_strides},
|
||||
{conv_device_buf.GetDeviceBuffer()},
|
||||
{out_device_buf.GetDeviceBuffer()},
|
||||
{e_g_n_k_wos_strides},
|
||||
{e_g_n_k_wos_strides},
|
||||
{conv_device_buf.GetDeviceBuffer()},
|
||||
{out_device_buf.GetDeviceBuffer()},
|
||||
scale_convert);
|
||||
|
||||
if(!device_ew_scale.IsSupportedArgument(scale_argument))
|
||||
|
||||
@@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example()
|
||||
{0, W * C, C, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1, 2, 3},
|
||||
1e-4,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -31,7 +31,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
example_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
|
||||
endif()
|
||||
set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
@@ -39,22 +39,22 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
|
||||
set(BLOCKSCALE_GEMM_OPTIONS )
|
||||
check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP)
|
||||
check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION)
|
||||
|
||||
if(hip_VERSION_FLAT LESS 600443483 OR hip_VERSION_FLAT GREATER_EQUAL 700000000)
|
||||
if(HAS_MISCHED_BOTTOMUP)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1")
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1")
|
||||
elseif(HAS_MISCHED_PRERA_DIRECTION)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup")
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup")
|
||||
endif()
|
||||
else()
|
||||
if(HAS_MISCHED_BOTTOMUP)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-bottomup=1")
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --misched-bottomup=1")
|
||||
elseif(HAS_MISCHED_PRERA_DIRECTION)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-prera-direction=bottomup")
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --misched-prera-direction=bottomup")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -62,7 +62,6 @@ check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupa
|
||||
if(HAS_MAX_OCCUPANCY_EXPERIMENTAL)
|
||||
list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental)
|
||||
endif()
|
||||
# list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-bottomup=1")
|
||||
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
|
||||
@@ -71,3 +70,5 @@ example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpresh
|
||||
|
||||
example_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS})
|
||||
|
||||
add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp)
|
||||
|
||||
267
example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp
Normal file
267
example/65_gemm_multiply_multiply/gemm_add_add_wmma_fp16.cpp
Normal file
@@ -0,0 +1,267 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F16;
|
||||
using B0DataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
|
||||
using EDataType = F16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Row;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
|
||||
using ELayout = Row;
|
||||
|
||||
struct AddAdd
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
|
||||
ck::half_t& e, const float& c, const float& d0, const float& d1) const
|
||||
{
|
||||
const float x0_f = c + d0 + d1;
|
||||
|
||||
e = ck::type_convert<ck::half_t>(x0_f);
|
||||
}
|
||||
};
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddAdd;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//#########################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//#########################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| S<C, D..>| | |
|
||||
< A0Layout, B0Layout, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideD = K;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideD = std::stoi(argv[9]);
|
||||
StrideE = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
|
||||
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
|
||||
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
|
||||
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
|
||||
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
|
||||
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
|
||||
break;
|
||||
default:
|
||||
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
|
||||
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_device_buf.ToDevice(a0_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
d0_device_buf.ToDevice(d0_m_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_m_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
|
||||
d1_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, NumDTensor>{StrideD, StrideD},
|
||||
StrideE,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
|
||||
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
|
||||
sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
|
||||
B0DataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -184,7 +184,6 @@ int main(int argc, char* argv[])
|
||||
b0_device_buf.ToDevice(b0_k_n.mData.data());
|
||||
d0_device_buf.ToDevice(d0_m_n.mData.data());
|
||||
d1_device_buf.ToDevice(d1_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
@@ -220,11 +219,12 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 20, 50});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N +
|
||||
sizeof(D0DataType) * M * N + sizeof(D1DataType) * M * N +
|
||||
sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
@@ -233,8 +233,6 @@ int main(int argc, char* argv[])
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
@@ -357,7 +357,7 @@ int main(int argc, char* argv[])
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ example_compile_options(example_moe_gemm1_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_M
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
set(FP8_MXGEMM_OPTIONS)
|
||||
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")
|
||||
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
|
||||
|
||||
22
example/68_gemm_add/CMakeLists.txt
Normal file
22
example/68_gemm_add/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
add_custom_target(example_gemm_add_xdl)
|
||||
|
||||
add_example_executable(example_gemm_add_xdl_fp16 gemm_add_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_fp16)
|
||||
|
||||
|
||||
add_example_executable(example_gemm_add_xdl_bf16 gemm_add_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_add_xdl example_gemm_add_xdl_bf16)
|
||||
|
||||
add_custom_target(example_gemm_add_wmma)
|
||||
|
||||
add_example_executable(example_gemm_add_wmma_bf16 gemm_add_wmma_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_bf16)
|
||||
|
||||
add_example_executable(example_gemm_add_wmma_fp16 gemm_add_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_add_wmma example_gemm_add_wmma_fp16)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
114
example/68_gemm_add/common.hpp
Normal file
114
example/68_gemm_add/common.hpp
Normal file
@@ -0,0 +1,114 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
using F16_Tuple = ck::Tuple<F16>;
|
||||
using BF16_Tuple = ck::Tuple<BF16>;
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideD = 4096;
|
||||
ck::index_t StrideE = 4096;
|
||||
};
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
inline bool
|
||||
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 13)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
problem_size.StrideA = std::stoi(argv[7]);
|
||||
problem_size.StrideB = std::stoi(argv[8]);
|
||||
problem_size.StrideD = std::stoi(argv[9]);
|
||||
problem_size.StrideE = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD,"
|
||||
"StrideE"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
78
example/68_gemm_add/gemm_add_wmma_bf16.cpp
Normal file
78
example/68_gemm_add/gemm_add_wmma_bf16.cpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = BF16;
|
||||
using DsDataType = BF16_Tuple;
|
||||
using EDataType = BF16;
|
||||
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using DLayout = Row;
|
||||
using DsLayout = Row_Tuple;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Add;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
|
||||
Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add,
|
||||
GemmSpec,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
2,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<4, 32, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 4>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
#include "run_gemm_add_example_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); }
|
||||
76
example/68_gemm_add/gemm_add_wmma_fp16.cpp
Normal file
76
example/68_gemm_add/gemm_add_wmma_fp16.cpp
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using DsDataType = F16_Tuple;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using DLayout = Row;
|
||||
using DsLayout = Row_Tuple;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Add;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
|
||||
Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add,
|
||||
GemmSpec,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
2,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<4, 32, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 4>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
#include "run_gemm_add_example_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); }
|
||||
82
example/68_gemm_add/gemm_add_xdl_bf16.cpp
Normal file
82
example/68_gemm_add/gemm_add_xdl_bf16.cpp
Normal file
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = BF16;
|
||||
using EDataType = BF16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Add;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
|
||||
#include "run_gemm_add_example_xdl.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); }
|
||||
82
example/68_gemm_add/gemm_add_xdl_fp16.cpp
Normal file
82
example/68_gemm_add/gemm_add_xdl_fp16.cpp
Normal file
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = Add;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
|
||||
#include "run_gemm_add_example_xdl.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_example(argc, argv); }
|
||||
145
example/68_gemm_add/run_gemm_add_example_wmma.inc
Normal file
145
example/68_gemm_add/run_gemm_add_example_wmma.inc
Normal file
@@ -0,0 +1,145 @@
|
||||
#pragma once
|
||||
|
||||
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
d_device_buf.ToDevice(d_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
|
||||
auto argument =
|
||||
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, 1>{StrideD},
|
||||
StrideE,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_add_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_add(problem_size, config);
|
||||
}
|
||||
144
example/68_gemm_add/run_gemm_add_example_xdl.inc
Normal file
144
example/68_gemm_add/run_gemm_add_example_xdl.inc
Normal file
@@ -0,0 +1,144 @@
|
||||
#pragma once
|
||||
|
||||
bool run_gemm_add(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
d_device_buf.ToDevice(d_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
|
||||
auto argument =
|
||||
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, 1>{StrideD},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_add_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) && run_gemm_add(problem_size, config);
|
||||
}
|
||||
15
example/69_gemm_add_relu/CMakeLists.txt
Normal file
15
example/69_gemm_add_relu/CMakeLists.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
add_custom_target(example_gemm_add_relu_xdl)
|
||||
|
||||
add_example_executable(example_gemm_add_relu_xdl_fp16 gemm_add_relu_xdl_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_add_relu_xdl example_gemm_add_relu_xdl_fp16)
|
||||
|
||||
add_example_executable(example_gemm_add_relu_xdl_bf16 gemm_add_relu_xdl_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_add_relu_xdl example_gemm_add_relu_xdl_bf16)
|
||||
|
||||
add_custom_target(example_gemm_add_relu_wmma)
|
||||
|
||||
add_example_executable(example_gemm_add_relu_wmma_bf16 gemm_add_relu_wmma_bf16.cpp)
|
||||
add_example_dependencies(example_gemm_add_relu_wmma example_gemm_add_relu_wmma_bf16)
|
||||
|
||||
add_example_executable(example_gemm_add_relu_wmma_fp16 gemm_add_relu_wmma_fp16.cpp)
|
||||
add_example_dependencies(example_gemm_add_relu_wmma example_gemm_add_relu_wmma_fp16)
|
||||
114
example/69_gemm_add_relu/common.hpp
Normal file
114
example/69_gemm_add_relu/common.hpp
Normal file
@@ -0,0 +1,114 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddRelu = ck::tensor_operation::element_wise::AddRelu;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
using F16_Tuple = ck::Tuple<F16>;
|
||||
using BF16_Tuple = ck::Tuple<BF16>;
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideD = 4096;
|
||||
ck::index_t StrideE = 4096;
|
||||
};
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
inline bool
|
||||
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 13)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
problem_size.StrideA = std::stoi(argv[7]);
|
||||
problem_size.StrideB = std::stoi(argv[8]);
|
||||
problem_size.StrideD = std::stoi(argv[9]);
|
||||
problem_size.StrideE = std::stoi(argv[10]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD,"
|
||||
"StrideE"
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
78
example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp
Normal file
78
example/69_gemm_add_relu/gemm_add_relu_wmma_bf16.cpp
Normal file
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = BF16;
|
||||
using DsDataType = BF16_Tuple;
|
||||
using EDataType = BF16;
|
||||
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using DLayout = Row;
|
||||
using DsLayout = Row_Tuple;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddRelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
|
||||
Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu,
|
||||
GemmSpec,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
2,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<4, 32, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 4>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
#include "run_gemm_add_relu_example_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }
|
||||
76
example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp
Normal file
76
example/69_gemm_add_relu/gemm_add_relu_wmma_fp16.cpp
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using DsDataType = F16_Tuple;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using DLayout = Row;
|
||||
using DsLayout = Row_Tuple;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddRelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
|
||||
Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu,
|
||||
GemmSpec,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
2,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
0,
|
||||
S<4, 32, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 4>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
#include "run_gemm_add_relu_example_wmma.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }
|
||||
82
example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp
Normal file
82
example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp
Normal file
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = BF16;
|
||||
using EDataType = BF16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddRelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
|
||||
#include "run_gemm_add_relu_example_xdl.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }
|
||||
82
example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp
Normal file
82
example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp
Normal file
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F16;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DLayout = Row;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AddRelu;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
256,
|
||||
128,
|
||||
32,
|
||||
8,
|
||||
8,
|
||||
32,
|
||||
32,
|
||||
4,
|
||||
2,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
S<4, 64, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
8>;
|
||||
|
||||
#include "run_gemm_add_relu_example_xdl.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_add_relu_example(argc, argv); }
|
||||
146
example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc
Normal file
146
example/69_gemm_add_relu/run_gemm_add_relu_example_wmma.inc
Normal file
@@ -0,0 +1,146 @@
|
||||
#pragma once
|
||||
|
||||
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
d_device_buf.ToDevice(d_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
|
||||
auto argument =
|
||||
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, 1>{StrideD},
|
||||
StrideE,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_add_relu_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) &&
|
||||
run_gemm_add_relu(problem_size, config);
|
||||
}
|
||||
145
example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc
Normal file
145
example/69_gemm_add_relu/run_gemm_add_relu_example_xdl.inc
Normal file
@@ -0,0 +1,145 @@
|
||||
#pragma once
|
||||
|
||||
bool run_gemm_add_relu(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideD, StrideE] = problem_size;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
d_device_buf.ToDevice(d_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
|
||||
auto argument =
|
||||
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, 1>{StrideD},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_add_relu_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) &&
|
||||
run_gemm_add_relu(problem_size, config);
|
||||
}
|
||||
@@ -24,26 +24,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
set(test 0)
|
||||
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
@@ -55,81 +56,74 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
|
||||
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl")
|
||||
message(DEBUG "removing dl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any DPP examples if DPP_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
|
||||
#Do not build any DPP examples if DPP_KERNELS not set
|
||||
if(NOT DEFINED DPP_KERNELS AND source_name MATCHES "_dpp")
|
||||
message(DEBUG "removing dpp example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl")
|
||||
message(DEBUG "removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma")
|
||||
message(DEBUG "removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any microscaling examples if gfx950 target is not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx")
|
||||
#Do not build any microscaling examples if gfx950 target is not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx950" AND source_name MATCHES "_mx")
|
||||
message(DEBUG "removing microscaling example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
|
||||
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
|
||||
if(NOT DEFINED CK_ENABLE_FP8 AND source_name MATCHES "_fp8")
|
||||
message(DEBUG "removing fp8 example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8")
|
||||
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
|
||||
if(NOT DEFINED CK_ENABLE_BF8 AND source_name MATCHES "_bf8")
|
||||
message(DEBUG "removing bf8 example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
# Build fp8 gemm_multiply_multiply and moe only on gfx94/95
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95")
|
||||
if (source MATCHES "fp8" AND source MATCHES "(gemm_multiply_multiply|moe)")
|
||||
message(DEBUG "Skipping ${source} example for current target")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
# Build fp8 gemm_multiply_multiply and moe only on gfx94/95
|
||||
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95")
|
||||
if(source_name MATCHES "fp8" AND source_name MATCHES "(gemm_multiply_multiply|moe)")
|
||||
message(DEBUG "Skipping ${source} example for current target")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
set(source_name_list "")
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
list(APPEND source_name_list ${source_name})
|
||||
endforeach()
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4")
|
||||
if(source_name_list MATCHES "_xdl" AND NOT source_name_list MATCHES "_pk_i4")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
elseif(source_name_list MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950
|
||||
elseif(source_name_list MATCHES "_mx") #only build mx example for gfx950
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950
|
||||
elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950
|
||||
message(DEBUG "trimming targets for ${FILE_NAME}")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
|
||||
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS})
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
add_dependencies(check ${EXAMPLE_NAME})
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
@@ -156,71 +150,71 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
message(DEBUG "adding example ${EXAMPLE_NAME}")
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
set(test 0)
|
||||
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
message(DEBUG "removing example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
set(test 0)
|
||||
if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
message(DEBUG "removing example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
set(source_name_list "")
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl")
|
||||
message(DEBUG "removing dl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl")
|
||||
message(DEBUG "removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma")
|
||||
message(DEBUG "removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
list(APPEND source_name_list ${source_name})
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
if(source_name_list MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
elseif(source_name_list MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS})
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
set(result 0)
|
||||
endif()
|
||||
|
||||
@@ -28,12 +28,14 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS}
|
||||
--optdim 32,64,128,256
|
||||
# --filter fmha_fwd...
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api bwd
|
||||
--receipt 3
|
||||
--optdim 32,64,128,256
|
||||
# --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd...
|
||||
)
|
||||
|
||||
@@ -142,6 +144,28 @@ list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
|
||||
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
|
||||
|
||||
# add fmha_fwd_v3 example
|
||||
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
|
||||
|
||||
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
|
||||
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
|
||||
)
|
||||
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
|
||||
fmha_fwd_v3.cpp
|
||||
${FMHA_FWD_V3_INSTANCES}
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
|
||||
-fgpu-flush-denormals-to-zero
|
||||
-Wno-undefined-func-template
|
||||
--save-temps
|
||||
)
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user