Merge remote-tracking branch 'origin/develop' into samremes/double_buffer_fp8_ab_scale

This commit is contained in:
Sami Remes
2025-09-02 12:13:43 +00:00
294 changed files with 16276 additions and 4639 deletions

112
.github/scripts/therock_configure_ci.py vendored Normal file
View File

@@ -0,0 +1,112 @@
import fnmatch
import json
import os
from pathlib import Path
import subprocess
import sys
from typing import Iterable, Optional, Mapping
def gha_set_output(vars: Mapping[str, str | Path]):
"""Sets values in a step's output parameters.
This appends to the file located at the $GITHUB_OUTPUT environment variable.
See
* https://docs.github.com/en/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter
* https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs
"""
print(f"Setting github output:\n{vars}")
step_output_file = os.getenv("GITHUB_OUTPUT")
if not step_output_file:
print(" Warning: GITHUB_OUTPUT env var not set, can't set github outputs")
return
with open(step_output_file, "a") as f:
f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items())
def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
"""Returns the paths of modified files relative to the base reference."""
try:
return subprocess.run(
["git", "diff", "--name-only", base_ref],
stdout=subprocess.PIPE,
check=True,
text=True,
timeout=60,
).stdout.splitlines()
except TimeoutError:
print(
"Computing modified files timed out. Not using PR diff to determine"
" jobs to run.",
file=sys.stderr,
)
return None
# Paths matching any of these patterns are considered to have no influence over
# build or test workflows so any related jobs can be skipped if all paths
# modified by a commit/PR match a pattern in this list.
SKIPPABLE_PATH_PATTERNS = [
"docs/*",
"*.gitignore",
"*.md",
"*.pre-commit-config.*",
"*LICENSE",
'Jenkinsfile',
'.github/ISSUE_TEMPLATE/*',
'.github/CODEOWNERS',
'.github/*.md',
'.github/dependabot.yml',
]
def is_path_skippable(path: str) -> bool:
"""Determines if a given relative path to a file matches any skippable patterns."""
return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS)
def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool:
"""Returns true if at least one path is not in the skippable set."""
if paths is None:
return False
return any(not is_path_skippable(p) for p in paths)
def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
"""Returns true if CI workflows should run given a list of modified paths."""
if paths is None:
print("No files were modified, skipping TheRock CI jobs")
return False
paths_set = set(paths)
github_workflows_paths = set(
[p for p in paths if p.startswith(".github/workflows")]
)
other_paths = paths_set - github_workflows_paths
contains_other_non_skippable_files = check_for_non_skippable_path(other_paths)
print("should_ci_run_given_modified_paths findings:")
print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}")
if contains_other_non_skippable_files:
print("Enabling TheRock CI jobs since a non-skippable path was modified")
return True
else:
print(
"Only unrelated and/or skippable paths were modified, skipping TheRock CI jobs"
)
return False
def main(args):
base_ref = args.get("base_ref")
modified_paths = get_modified_paths(base_ref)
print("modified_paths (max 200):", modified_paths[:200])
enable_jobs = should_ci_run_given_modified_paths(modified_paths)
output = {
'enable_therock_ci': json.dumps(enable_jobs)
}
gha_set_output(output)
if __name__ == "__main__":
args = {}
args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1")
main(args)

View File

@@ -21,9 +21,11 @@ jobs:
id-token: write
container:
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292
options: -v /runner/config:/home/awsconfig/
env:
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
TEATIME_FORCE_INTERACTIVE: 0
AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini
steps:
- name: Checkout composable_kernel repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -83,9 +85,9 @@ jobs:
echo "----------"
du -h -d 1 TheRock/build/artifacts
- name: Configure AWS Credentials
if: always()
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
- name: Configure AWS Credentials for non-forked repos
if: ${{ always() && !github.event.pull_request.head.repo.fork }}
uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1
with:
aws-region: us-east-2
role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external

View File

@@ -5,6 +5,15 @@ on:
branches:
- develop
workflow_dispatch:
pull_request:
types:
- opened
- synchronize
branches:
- mainline
- release/*
- release-staging/*
- develop
permissions:
contents: read
@@ -18,8 +27,29 @@ concurrency:
cancel-in-progress: true
jobs:
setup:
runs-on: ubuntu-24.04
env:
# The commit being checked out is the merge commit for a PR. Its first
# parent will be the tip of the base branch.
BASE_REF: HEAD^
outputs:
enable_therock_ci: ${{ steps.configure.outputs.enable_therock_ci }}
steps:
- name: "Checking out repository"
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
# We need the parent commit to do a diff
fetch-depth: 2
- name: "Configuring CI options"
id: configure
run: python .github/scripts/therock_configure_ci.py
therock-ci-linux:
name: TheRock CI Linux
needs: setup
if: ${{ needs.setup.outputs.enable_therock_ci == 'true' }}
permissions:
contents: read
id-token: write
@@ -34,6 +64,7 @@ jobs:
name: TheRock CI Summary
if: always()
needs:
- setup
- therock-ci-linux
runs-on: ubuntu-24.04
steps:

View File

@@ -68,6 +68,7 @@ jobs:
VENV_DIR: ${{ env.VENV_DIR }}
FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }}
PLATFORM: ${{ inputs.platform }}
IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }}
- name: Test
timeout-minutes: ${{ matrix.components.timeout_minutes }}

View File

@@ -2,7 +2,7 @@
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
## Composable Kernel 1.1.0 for ROCm 7.0.0
## Composable Kernel 1.2.0 for ROCm 7.0.0
### Added
@@ -27,6 +27,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added int8 support for CK_TILE GEMM.
* Added support for elementwise kernel.
* Added benchmarking support for tile engine GEMM Multi D.
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
### Optimized
@@ -48,6 +49,7 @@ None
* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced.
* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced.
* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced.
* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594)
### Known issues

View File

@@ -16,12 +16,21 @@ else()
"Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.")
endif()
# Allow user to specify the C++ standard.
# We must support C++17 builds until downstream users are migrated to C++20, but we default to C++20.
set(CK_CXX_STANDARD "20" CACHE STRING "C++ standard to use (e.g. 17 or 20)")
set(valid_cxx_standards 17 20)
set_property(CACHE CK_CXX_STANDARD PROPERTY STRINGS ${valid_cxx_standards})
if(NOT CK_CXX_STANDARD IN_LIST valid_cxx_standards)
message(FATAL_ERROR "CK_CXX_STANDARD must be one of ${valid_cxx_standards}")
endif()
# Default installation path
if(NOT WIN32)
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
endif()
set(version 1.1.0)
set(version 1.2.0)
# Check support for CUDA/HIP in Cmake
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
include(CTest)
@@ -221,11 +230,20 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9
add_definitions(-DCK_USE_GFX94)
set(CK_USE_GFX94 "ON")
endif()
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
set(CK_TILE_USE_WMMA 0)
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message(STATUS "Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
set(CK_TILE_USE_WMMA 1)
endif()
# define the macro with the current value (0 or 1)
add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
add_definitions(-DCK_USE_WMMA_FP8)
@@ -324,32 +342,19 @@ if(USE_BITINT_EXTENSION_INT4)
message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
endif()
if(USE_OPT_GFX11)
add_compile_options(-mcumode)
add_compile_options(-mno-wavefrontsize64)
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
endif()
if(ENABLE_ASM_DUMP)
add_compile_options(--save-temps)
add_compile_options(-Wno-gnu-line-marker)
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
endif()
if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12"))
add_compile_options(-mno-wavefrontsize64)
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}")
endif()
## Threads
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
link_libraries(Threads::Threads)
## C++
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD ${CK_CXX_STANDARD})
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")

23
Dockerfile.pytorch Normal file
View File

@@ -0,0 +1,23 @@
ARG BASE_DOCKER="rocm/pytorch-nightly:latest"
FROM $BASE_DOCKER
ARG CK_PYTORCH_BRANCH="develop"
RUN groupadd -g 109 render && \
usermod -u 1001 jenkins && \
groupmod -g 1001 jenkins && \
cd /tmp/pytorch && \
rm -rf build && \
cd /tmp/pytorch/third_party && \
rm -rf composable_kernel && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
cd /tmp/pytorch/third_party/aiter/3rdparty && \
rm -rf composable_kernel && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
cd /tmp/pytorch/third_party/fbgemm/external && \
rm -rf composable_kernel && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
cd /tmp/pytorch/third_party/flash-attention/csrc && \
rm -rf composable_kernel && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
chown -R jenkins:jenkins /tmp/pytorch && \
chmod -R a+rwx /tmp/pytorch && \
sudo usermod -aG irc jenkins

237
Jenkinsfile vendored
View File

@@ -192,12 +192,16 @@ def buildDocker(install_prefix){
image_name = "rocm/composable_kernel:ck_aiter"
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . "
}
else{
else if(params.RUN_PYTORCH_TESTS){
image_name = "rocm/composable_kernel:ck_pytorch"
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . "
}
else{
dockerArgs = dockerArgs + " -f Dockerfile . "
}
echo "Build Args: ${dockerArgs}"
try{
if(params.BUILD_DOCKER || params.RUN_AITER_TESTS){
if(params.BUILD_DOCKER || params.RUN_AITER_TESTS || params.RUN_PYTORCH_TESTS){
//force building the new docker if that parameter is true
echo "Building image: ${image_name}"
retimage = docker.build("${image_name}", dockerArgs)
@@ -400,8 +404,9 @@ def cmake_build(Map conf=[:]){
echo "Build packages"
sh 'ninja -j64 package'
archiveArtifacts artifacts: 'composablekernel-dev*.deb'
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb'
stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages"
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb'
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64.deb'
stash includes: "composablekernel-**.deb", name: "packages"
}
}
else{
@@ -571,50 +576,66 @@ def Build_CK(Map conf=[:]){
python3 -m pytest python/test/test_gen_instances.py
"""
}
dir("build"){
if (params.RUN_FULL_QA && arch == 2 ){
// build deb packages
echo "Build packages"
sh 'ninja package'
archiveArtifacts artifacts: 'composablekernel*.deb'
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb'
sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb'
sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb'
sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb'
stash includes: "composablekernel-**.deb", name: "packages"
}
}
// run performance tests, stash the logs, results will be processed on the master node
dir("script"){
if (params.RUN_PERFORMANCE_TESTS){
if (params.RUN_FULL_QA && arch == 1){
// run full tests on gfx90a
echo "Run full performance tests"
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
archiveArtifacts "perf_gemm.log"
archiveArtifacts "perf_resnet50_N256.log"
archiveArtifacts "perf_resnet50_N4.log"
archiveArtifacts "perf_batched_gemm.log"
archiveArtifacts "perf_grouped_gemm.log"
archiveArtifacts "perf_grouped_conv_fwd.log"
archiveArtifacts "perf_grouped_conv_bwd_data.log"
archiveArtifacts "perf_grouped_conv_bwd_weight.log"
archiveArtifacts "perf_gemm_bilinear.log"
archiveArtifacts "perf_reduction.log"
archiveArtifacts "perf_splitK_gemm.log"
archiveArtifacts "perf_onnx_gemm.log"
archiveArtifacts "perf_mixed_gemm.log"
stash includes: "perf_**.log", name: "perf_log"
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a"
archiveArtifacts "perf_gemm_gfx90a.log"
archiveArtifacts "perf_resnet50_N256_gfx90a.log"
archiveArtifacts "perf_resnet50_N4_gfx90a.log"
archiveArtifacts "perf_batched_gemm_gfx90a.log"
archiveArtifacts "perf_grouped_gemm_gfx90a.log"
archiveArtifacts "perf_grouped_conv_fwd_gfx90a.log"
archiveArtifacts "perf_grouped_conv_bwd_data_gfx90a.log"
archiveArtifacts "perf_grouped_conv_bwd_weight_gfx90a.log"
archiveArtifacts "perf_gemm_bilinear_gfx90a.log"
archiveArtifacts "perf_reduction_gfx90a.log"
archiveArtifacts "perf_splitK_gemm_gfx90a.log"
archiveArtifacts "perf_onnx_gemm_gfx90a.log"
archiveArtifacts "perf_mixed_gemm_gfx90a.log"
stash includes: "perf_**.log", name: "perf_log_gfx90a"
}
if (params.RUN_FULL_QA && arch == 2){
// run full tests on gfx942
echo "Run full performance tests"
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942"
archiveArtifacts "perf_gemm_gfx942.log"
archiveArtifacts "perf_resnet50_N256_gfx942.log"
archiveArtifacts "perf_resnet50_N4_gfx942.log"
archiveArtifacts "perf_batched_gemm_gfx942.log"
archiveArtifacts "perf_grouped_gemm_gfx942.log"
archiveArtifacts "perf_grouped_conv_fwd_gfx942.log"
archiveArtifacts "perf_grouped_conv_bwd_data_gfx942.log"
archiveArtifacts "perf_grouped_conv_bwd_weight_gfx942.log"
archiveArtifacts "perf_gemm_bilinear_gfx942.log"
archiveArtifacts "perf_reduction_gfx942.log"
archiveArtifacts "perf_splitK_gemm_gfx942.log"
archiveArtifacts "perf_onnx_gemm_gfx942.log"
archiveArtifacts "perf_mixed_gemm_gfx942.log"
stash includes: "perf_**.log", name: "perf_log_gfx942"
}
else if ( arch == 1 ){
// run standard tests on gfx90a
echo "Run performance tests"
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
archiveArtifacts "perf_gemm.log"
archiveArtifacts "perf_onnx_gemm.log"
archiveArtifacts "perf_resnet50_N256.log"
archiveArtifacts "perf_resnet50_N4.log"
stash includes: "perf_**.log", name: "perf_log"
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a"
archiveArtifacts "perf_gemm_gfx90a.log"
archiveArtifacts "perf_onnx_gemm_gfx90a.log"
archiveArtifacts "perf_resnet50_N256_gfx90a.log"
archiveArtifacts "perf_resnet50_N4_gfx90a.log"
stash includes: "perf_**.log", name: "perf_log_gfx90a"
}
else if ( arch == 2 ){
// run standard tests on gfx942
echo "Run performance tests"
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942"
archiveArtifacts "perf_gemm_gfx942.log"
archiveArtifacts "perf_onnx_gemm_gfx942.log"
archiveArtifacts "perf_resnet50_N256_gfx942.log"
archiveArtifacts "perf_resnet50_N4_gfx942.log"
stash includes: "perf_**.log", name: "perf_log_gfx942"
}
// disable performance tests on gfx1030 for now.
//else if ( arch == 3){
@@ -732,29 +753,64 @@ def process_results(Map conf=[:]){
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
unstash "perf_fmha_log_gfx942"
}
catch(Exception err){
echo "could not locate the FMHA performance logs for gfx942: ${err.getMessage()}."
}
try{
unstash "perf_fmha_log_gfx90a"
}
catch(Exception err){
echo "could not locate the FMHA performance logs: ${err.getMessage()}."
echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}."
}
}
if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){
if (params.BUILD_INSTANCES_ONLY){
// unstash deb packages
unstash "packages"
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
}
else{
// unstash perf files to master
unstash "perf_log"
try{
unstash "perf_log_gfx90a"
}
catch(Exception err){
echo "could not locate the gfx90a performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx942"
}
catch(Exception err){
echo "could not locate the gfx942 performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx950"
}
catch(Exception err){
echo "could not locate the gfx950 performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx908"
}
catch(Exception err){
echo "could not locate the gfx908 performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx11"
}
catch(Exception err){
echo "could not locate the gfx11 performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx12"
}
catch(Exception err){
echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}."
echo "could not locate the gfx12 performance logs: ${err.getMessage()}."
}
sh "./process_perf_data.sh"
}
// process the logs
sh "./process_perf_data.sh"
}
}
catch(e){
@@ -819,13 +875,64 @@ def run_aiter_tests(Map conf=[:]){
}
}
def run_pytorch_tests(Map conf=[:]){
show_node_info()
env.HSA_ENABLE_SDMA=0
checkout scm
//use the latest pytorch-nightly image
def image = "rocm/composable_kernel:ck_pytorch"
def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins"
def variant = env.STAGE_NAME
def retimage
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
echo "Docker flags: ${dockerOpts}"
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try
{
echo "Pulling image: ${image}"
retimage = docker.image("${image}")
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
retimage.pull()
}
}
catch(Exception ex)
{
error "Unable to locate image: ${image}"
}
}
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 45, unit: 'MINUTES'){
try{
sh "rocminfo"
sh "python3 --version"
sh "python3 /tmp/pytorch/tools/amd_build/build_amd.py"
sh "USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python /tmp/pytorch/setup.py develop"
}
catch(e){
echo "Throwing error exception while building Pytorch"
echo 'Exception occurred: ' + e.toString()
throw e
}
finally{
echo "Finished building Pytorch"
}
}
}
}
//launch develop branch daily jobs
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : ""
0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false
0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX10=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false''' : ""
pipeline {
agent none
@@ -960,6 +1067,14 @@ pipeline {
name: "RUN_ALL_UNIT_TESTS",
defaultValue: false,
description: "Run all unit tests (default: OFF)")
booleanParam(
name: "RUN_PYTORCH_TESTS",
defaultValue: false,
description: "Try building PYTORCH with latest CK develop branch (default: OFF)")
string(
name: 'ck_pytorch_branch',
defaultValue: 'develop',
description: 'Specify which branch of CK to test with Pytorch (default: develop)')
booleanParam(
name: "RUN_AITER_TESTS",
defaultValue: false,
@@ -1051,6 +1166,24 @@ pipeline {
}
}
}
}
stage("Run Pytorch Tests")
{
parallel
{
stage("Run Pytorch Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_PYTORCH_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx942")}
steps{
run_pytorch_tests()
cleanWs()
}
}
}
}
stage("Run AITER Tests")
{
@@ -1107,11 +1240,16 @@ pipeline {
agent{ label rocmnode("gfx90a")}
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cd test_data && \
./generate_test_dataset.sh && \
cd ../script && \
execute_args = """ cd ../build && \
../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 test_grouped_convnd_fwd_dataset_xdl && \
cd ../test_data && \
# Dataset generation modes:
# - small: ~60 test cases (minimal, quick testing - 3 models, 2 batch sizes, 2 image sizes)
# - half: ~300 test cases (moderate coverage - 16 models, 3 batch sizes, 5 image sizes), ~ 17 hours testing time
# - full: ~600 test cases (comprehensive - 16 models, 5 batch sizes, 9 image sizes), ~ 40 hours testing time
./generate_test_dataset.sh half && \
cd ../build && \
./bin/test_grouped_convnd_fwd_dataset_xdl"""
}
steps{
@@ -1306,6 +1444,7 @@ pipeline {
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3"
setup_args = """ -DGPU_TARGETS="gfx942" \
-DCMAKE_CXX_FLAGS=" -O3 " \
-DCK_CXX_STANDARD="17" \
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
execute_args = " "
}
@@ -1440,7 +1579,7 @@ pipeline {
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0")
}
cleanWs()
}
@@ -1517,7 +1656,7 @@ pipeline {
stage("Process results"){
when {
beforeAgent true
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent { label 'mici' }
steps{

View File

@@ -19,7 +19,6 @@ Getting started
build the library. You can also find some of this information in the
`README file <https://github.com/ROCm/composable_kernel/blob/develop/README.md>`_
on the project's GitHub page.
#. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_ provides a deeper understanding of the CK library and showcases its performance capabilities.
<https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_
from the AMD Community portal. It offers a deeper understanding of the library's objectives and showcases its performance capabilities.
#. **General information:** For broader information about AMD products, consider exploring the

View File

@@ -1,7 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/utility/validation_common.hpp"
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
@@ -53,6 +54,17 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
try
{
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
M, N, K, StrideA, StrideB, StrideC);
}
catch(const std::runtime_error& e)
{
std::cerr << "Error: " << e.what() << std::endl;
return false;
}
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -144,6 +144,28 @@ list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
# add fmha_fwd_v3 example
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
)
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
fmha_fwd_v3.cpp
${FMHA_FWD_V3_INSTANCES}
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-fgpu-flush-denormals-to-zero
-Wno-undefined-func-template
--save-temps
)
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global

View File

@@ -7,7 +7,7 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_fmha_fwd -j
```
This will result in an executable `build/bin/tile_example_fmha_fwd`

View File

@@ -110,9 +110,9 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -136,10 +136,10 @@ float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
@@ -148,9 +148,9 @@ void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_co
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
@@ -425,10 +425,10 @@ float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
@@ -436,9 +436,9 @@ void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_co
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
@@ -530,10 +530,10 @@ float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_confi
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
@@ -542,9 +542,9 @@ void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::strea
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}

View File

@@ -110,9 +110,9 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
@@ -385,7 +385,7 @@ class FmhaFwdApiPool:
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][(hdim, hdim_v)]
traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'

View File

@@ -60,9 +60,9 @@ float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fw
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -108,9 +108,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
}}
@@ -208,9 +208,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
}}

View File

@@ -109,9 +109,9 @@ float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -0,0 +1,492 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <ck_tile/core/numeric/bfloat16.hpp>
#include <ck_tile/core/numeric/half.hpp>
#include <ck_tile/core/numeric/math.hpp>
#include <ck_tile/core/utility/functional.hpp>
#include <ck_tile/host/arg_parser.hpp>
#include <ck_tile/host/device_memory.hpp>
#include <ck_tile/host/fill.hpp>
#include <ck_tile/host/check_err.hpp>
#include <ck_tile/host/host_tensor.hpp>
#include <ck_tile/host/reference/reference_batched_gemm.hpp>
#include <ck_tile/host/reference/reference_batched_masking.hpp>
#include <ck_tile/host/reference/reference_batched_softmax.hpp>
#include "fmha_fwd.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("prec", "fp16", "data type. fp16/bf16")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s", "3328", "seqlen_q")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q & k")
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
.insert("iperm",
"0",
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "0", "permute output")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
"'xt:window_size', xformer style masking from top-left, window_size negative is "
"causal, positive is swa\n"
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
"causal, positive is swa\n"
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
"now)")
.insert("v", "1", "0:no verify, 1:verify")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "30", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_pair(result, arg_parser);
}
enum class TensorLayout
{
bhsd,
bshd,
};
std::ostream& operator<<(std::ostream& stream, TensorLayout layout)
{
switch(layout)
{
case TensorLayout::bhsd: return stream << "bhsd";
case TensorLayout::bshd: return stream << "bshd";
default: return stream << "unknown";
}
}
struct Problem
{
explicit Problem(const ck_tile::ArgParser& args)
{
data_type = args.get_str("prec") == "fp16"
? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16
: ck_tile::fmha_fwd_v3_args::data_type_enum::bf16;
batch = args.get_int("b");
seqlen_q = args.get_int("s");
seqlen_k = args.get_int("s_k");
if(seqlen_k < 0)
{
seqlen_k = seqlen_q;
}
nhead_q = args.get_int("h");
nhead_kv = args.get_int("h_k");
if(nhead_kv < 0)
{
nhead_kv = nhead_q;
}
hdim = args.get_int("d");
softmax_scale = args.get_float("scale_s");
if(softmax_scale == .0f)
softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim));
mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k);
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
}
std::vector<ck_tile::index_t> get_query_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
std::vector<ck_tile::index_t> get_key_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_value_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_output_shape() const
{
if(output_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
ck_tile::fmha_fwd_v3_args::data_type_enum data_type;
ck_tile::index_t batch;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_kv;
ck_tile::index_t hdim;
float softmax_scale;
mask_info mask;
TensorLayout input_layout;
TensorLayout output_layout;
};
struct RunConfig
{
explicit RunConfig(const ck_tile::ArgParser& args)
{
seed = args.get_uint32("seed");
if(*seed == 0)
{
seed.reset();
}
kernel_warmup = args.get_int("warmup");
kernel_repeat = args.get_int("repeat");
verify = args.get_bool("v");
}
std::optional<uint32_t> seed;
int kernel_warmup;
int kernel_repeat;
bool verify;
};
template <typename DataType>
auto generate_qkv(const Problem& problem,
[[maybe_unused]] std::optional<uint32_t> seed = std::nullopt)
-> std::tuple<ck_tile::HostTensor<DataType>,
ck_tile::HostTensor<DataType>,
ck_tile::HostTensor<DataType>>
{
ck_tile::HostTensor<DataType> q(problem.get_query_shape());
ck_tile::HostTensor<DataType> k(problem.get_key_shape());
ck_tile::HostTensor<DataType> v(problem.get_value_shape());
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(q);
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(k);
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(v);
return std::make_tuple(q, k, v);
}
namespace host {
template <typename AccDataType,
typename PDataType,
typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename QElementOp,
typename KElementOp,
typename VElementOp,
typename SAccElementOp>
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
const ck_tile::HostTensor<KDataType>& k_bshd,
const ck_tile::HostTensor<VDataType>& v_bshd,
const mask_info& mask,
ck_tile::HostTensor<ODataType>& o_bshd,
const QElementOp& q_element_op = {},
const KElementOp& k_element_op = {},
const VElementOp& v_element_op = {},
const SAccElementOp& s_acc_element_op = {})
{
const int batch_size = q_bshd.mDesc.get_lengths()[0];
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
const int nr = nhead_q / nhead_kv;
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
// do computation for each batch
for(int b = 0; b < batch_size; ++b)
{
// copy per-batch data from input tensors
// clang-format off
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); });
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); });
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// clang-format on
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
if(mask.type == mask_enum::no_mask)
{
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
}
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, seqlen_q, seqlen_kv));
}
else
{
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
}
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// copy resulting per-batch data to the output tensor
o_host_ref.ForEach(
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
}
}
} // namespace host
template <typename DataType>
bool run_impl(const Problem& problem, const RunConfig& run_config)
{
auto [q, k, v] = generate_qkv<DataType>(problem, run_config.seed);
ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes());
/// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v
ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes());
q_buf.ToDevice(q.data());
k_buf.ToDevice(k.data());
v_buf.ToDevice(v.data());
ck_tile::fmha_fwd_v3_args args;
args.data_type = problem.data_type;
args.batch = problem.batch;
args.seqlen_q = problem.seqlen_q;
args.seqlen_k = problem.seqlen_k;
args.nhead_q = problem.nhead_q;
args.nhead_kv = problem.nhead_kv;
args.hdim_qk = problem.hdim;
args.hdim_v = problem.hdim;
args.softmax_scale = problem.softmax_scale;
args.window_size_left = problem.mask.left;
args.window_size_right = problem.mask.right;
args.mask_type = static_cast<ck_tile::index_t>(problem.mask.type);
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.q_ptr = q_buf.GetDeviceBuffer();
args.stride_q =
problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_q =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim;
args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.k_ptr = k_buf.GetDeviceBuffer();
args.stride_k =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_k =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.v_ptr = v_buf.GetDeviceBuffer();
args.stride_v =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_v =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.o_ptr = o_buf.GetDeviceBuffer();
args.stride_o =
problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_o = problem.output_layout == TensorLayout::bshd
? problem.hdim
: problem.seqlen_q * problem.hdim;
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
ck_tile::stream_config stream_config{nullptr,
true,
/*log_level=*/0,
run_config.kernel_warmup,
run_config.kernel_repeat};
auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config);
if(!result)
{
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
return false;
}
std::size_t flop = [&] {
if(problem.mask.type == mask_enum::no_mask)
{
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
else
{
/// FIXME: Use a more accurate method; for now, were just dividing the flop by 2.
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
}();
float tflops = static_cast<float>(flop) / 1.e9 / time;
std::cout << "[" << problem.data_type << "|";
if(problem.input_layout == problem.output_layout)
{
std::cout << problem.input_layout;
}
else
{
std::cout << problem.input_layout << "-" << problem.output_layout;
}
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
<< ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
<< ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
<< ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
<< " TFlops" << std::endl;
if(!run_config.verify)
{
return true;
}
// transpose tensor descriptors from bhsd to bshd if necessary
if(problem.input_layout != TensorLayout::bshd)
{
q = q.transpose({0, 2, 1, 3});
k = k.transpose({0, 2, 1, 3});
v = v.transpose({0, 2, 1, 3});
}
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
if(problem.output_layout != TensorLayout::bshd)
{
o_ref = o_ref.transpose({0, 2, 1, 3});
}
host::fmha_fwd<float, DataType>(q,
k,
v,
problem.mask,
o_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);
else
return std::make_tuple(1e-2, 1e-2);
}();
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
}
int main(int argc, char* argv[])
{
auto [parse_result, args] = parse_cmd_args(argc, argv);
if(!parse_result)
{
std::cerr << "failed to parse command line arguments" << std::endl;
}
Problem problem(args);
RunConfig run_config(args);
const auto run = [&] {
if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16)
{
return run_impl<ck_tile::fp16_t>(problem, run_config);
}
else
{
return run_impl<ck_tile::bf16_t>(problem, run_config);
}
};
return !run();
}

View File

@@ -809,20 +809,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, "
"bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n",
fmha_traits.hdim_q,
fmha_traits.hdim_v,
fmha_traits.data_type.c_str(),
fmha_traits.is_group_mode,
static_cast<int>(fmha_traits.mask_type),
static_cast<int>(fmha_traits.bias_type),
fmha_traits.has_dbias,
fmha_traits.has_dropout,
fmha_traits.is_store_randval,
fmha_traits.is_deterministic);
fflush(stdout);
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
dq_buf.FromDevice(dq_host.data());

View File

@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
#include "mask.hpp"
namespace ck_tile {
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type)
{
switch(data_type)
{
case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16";
case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16";
default: return stream << "unknown";
}
}
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config)
{
if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
}
else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
}
return std::make_pair(false, -1.f);
}
} // namespace ck_tile

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/stream_config.hpp"
namespace ck_tile {
struct fmha_fwd_v3_args
{
enum class data_type_enum
{
fp16,
bf16
};
data_type_enum data_type;
// bool is_varlen;
index_t batch;
index_t seqlen_q;
index_t seqlen_k;
index_t nhead_q;
index_t nhead_kv;
index_t hdim_qk;
index_t hdim_v;
float softmax_scale;
index_t window_size_left;
index_t window_size_right;
index_t mask_type;
const void* q_ptr;
index_t stride_q;
index_t nhead_stride_q;
index_t batch_stride_q;
const void* k_ptr;
index_t stride_k;
index_t nhead_stride_k;
index_t batch_stride_k;
const void* v_ptr;
index_t stride_v;
index_t nhead_stride_v;
index_t batch_stride_v;
void* o_ptr;
index_t stride_o;
index_t nhead_stride_o;
index_t batch_stride_o;
};
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config);
} // namespace ck_tile

View File

@@ -0,0 +1,159 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <utility>
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "fmha_fwd_v3.hpp"
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch<kernel_traits>( \
const fmha_fwd_v3_args& args, const stream_config& config) \
{ \
return std::make_pair(true, \
fmha_fwd_v3_kernel_launch<kernel_traits::kernel>(args, config)); \
}
namespace ck_tile {
template <fmha_fwd_v3_args::data_type_enum DataType>
struct fmha_fwd_v3_problem_traits;
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::fp16>
{
using qkvp_dtype = ck_tile::half_t;
using acc_dtype = float;
using o_dtype = ck_tile::half_t;
using lse_dtype = float;
};
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
{
using qkvp_dtype = ck_tile::bf16_t;
using acc_dtype = float;
using o_dtype = ck_tile::bf16_t;
using lse_dtype = float;
};
template <fmha_fwd_v3_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
struct fmha_fwd_v3_kernel_traits
{
static constexpr auto date_type = DataType;
static constexpr bool is_variable_seqlen = IsVariableSeqlen;
static constexpr bool is_masking = IsMasking;
// M0 N0 K0 N1 K1
using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>;
using fmha_warp_gemm_shape = sequence<32, 32, 16>;
using fmha_block_warps = sequence<8, 1, 1>;
using fmha_shape = TileFmhaShape<fmha_block_tile,
fmha_block_warps,
fmha_warp_gemm_shape,
fmha_block_warps,
fmha_warp_gemm_shape,
true // IsVLayoutRowMajor
>;
using fmha_traits = TileFmhaFwdV3Traits<true, // kPadSeqLenQ
true, // kPadSeqLenK
false, // kPadHeadDimQ
false, // kPadHeadDimV
false, // kStoreLSE
-1 // kBlockPerCu
>;
using fmha_mask = SimplifiedGenericAttentionMask<IsMasking>;
using fmha_pipeline_problem =
BlockFmhaFwdV3PipelineProblem<typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::lse_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
fmha_shape,
IsVariableSeqlen,
fmha_mask,
fmha_traits>;
using fmha_pipeline = BlockFmhaFwdV3Pipeline<fmha_pipeline_problem>;
using epilogue = Default2DEpilogue<
Default2DEpilogueProblem<typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
true, // kPadM
true, // kPadM
true // UseRawStore
>>;
using kernel = FmhaFwdV3Kernel<fmha_pipeline, epilogue>;
};
template <typename Kernel>
float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config)
{
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_qk,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_kv,
args.softmax_scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
0, // nhead_stride_lse
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
0, // batch_stride_lse
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type);
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
template <typename KernelTraits>
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args,
const stream_config& config);
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,31 @@
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)"
VALID=0
for causal in 0 1 ; do
for prec in "fp16" "bf16" ; do
for hdim in 128 ; do
for perm in 0 ; do
if [ $causal -eq 0 ]; then
mask=0
else
mask=b:-1,0
fi
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
done
done
done
done

View File

@@ -9,6 +9,8 @@
# host name : $hostname
# gpu architecture: e.g., gfx90a, or gfx942, etc.
set -euo pipefail
#get the command line arguments:
export env_type=$1
echo 'Environment type: ' $env_type

View File

@@ -1,5 +1,7 @@
#!/bin/sh
#!/bin/bash
# TODO: run this script from CK root or build directory
set -euo pipefail
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
KNAME=1
@@ -17,12 +19,12 @@ for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done

View File

@@ -1,5 +1,7 @@
#!/bin/bash
# TODO: run this script from CK root or build directory
set -euo pipefail
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
KNAME=1
@@ -51,19 +53,18 @@ run_fp16_bf16_tests() {
for cache_batch_idx in $CACHE_BATCH_IDX ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ;
}
run_fp8_tests() {

View File

@@ -42,7 +42,7 @@ return hidden_states, per_token_scale
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_layernorm2d_fwd -j
```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`

View File

@@ -235,7 +235,7 @@ float layernorm2d_fwd_(const S& s, A a)
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -243,7 +243,7 @@ float layernorm2d_fwd_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -7,7 +7,7 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation

View File

@@ -26,6 +26,15 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
#if CK_TILE_USE_WMMA
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
@@ -33,6 +42,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -65,7 +75,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -81,8 +90,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -100,10 +109,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
<< std::endl;
}
float ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
@@ -208,7 +208,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -232,7 +231,7 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -279,15 +278,13 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
@@ -373,7 +370,7 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
float ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,

4
example/ck_tile/03_gemm/gemm_utils.hpp Executable file → Normal file
View File

@@ -172,6 +172,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
#if CK_TILE_USE_WMMA
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
@@ -192,6 +193,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
#endif
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
@@ -484,7 +486,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -126,7 +125,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -172,15 +171,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};

View File

@@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -127,7 +126,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -173,15 +172,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
@@ -338,7 +335,11 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -7,7 +7,7 @@ This folder contains example for Image to Column using ck_tile tile-programming
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j
```
This will result in an executable `build/bin/tile_example_img2col`

View File

@@ -55,13 +55,12 @@ float image_to_column(const image_to_column_traits& traits,
args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1],
args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C,
args.G);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 2;
float ave_time = ck_tile::launch_kernel(
stream_conf,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
stream_conf, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}

View File

@@ -94,18 +94,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
throw std::runtime_error("Wrong! Arguments not supported!\n");
}
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
input_shape,
input_strides,
kept_dim,
reduce_dims));
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
input_shape,
input_strides,
kept_dim,
reduce_dims));
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;

View File

@@ -15,7 +15,7 @@ args:
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_permute -j
```
This will result in an executable `build/bin/tile_example_permute`

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -115,11 +115,12 @@ struct matrix_core_swizzle_kernel
__host__ void operator()(const ck_tile::stream_config& s) const
{
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
ck_tile::kentry<1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
}
struct kernel
{
static constexpr int kBlockSize = BLOCK_SIZE;
__device__ static constexpr auto get_src_dist()
{
using namespace ck_tile;

View File

@@ -53,11 +53,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
@@ -69,11 +69,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
@@ -85,11 +85,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}

View File

@@ -6,7 +6,7 @@ This folder contains example for topk-softmax kernel using ck_tile tile-programm
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_topk_softmax -j
```
This will result in an executable `build/bin/tile_example_topk_softmax`

View File

@@ -13,11 +13,11 @@
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
constexpr dim3 blocks = kernel::BlockSize(); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs)); \
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;

View File

@@ -6,7 +6,7 @@ This folder contains example for Rmsnorm2D forward using ck_tile tile-programmin
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_rmsnorm2d_fwd -j
```
This will result in an executable `build/bin/tile_rmsnorm2d_fwd`

View File

@@ -138,12 +138,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;

View File

@@ -249,7 +249,7 @@ float rmsnorm2d_fwd_(const S& s, A a)
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -257,7 +257,7 @@ float rmsnorm2d_fwd_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -6,7 +6,7 @@ This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_add_rmsnorm2d_rdquant_fwd -j
```
This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd`

View File

@@ -136,12 +136,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;

View File

@@ -58,7 +58,7 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -66,5 +66,5 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -6,7 +6,7 @@ This folder contains example for smoothquant using ck_tile tile-programming impl
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_smoothquant -j
```
This will result in an executable `build/bin/tile_smoothquant`

View File

@@ -126,12 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 1, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;

View File

@@ -50,7 +50,7 @@ float smoothquant_(const S& s, A a)
using Kernel = ck_tile::Smoothquant<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -58,5 +58,5 @@ float smoothquant_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -6,7 +6,7 @@ This folder contains example for moe-sorting kernel using ck_tile tile-programmi
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_sorting -j
```
This will result in an executable `build/bin/tile_example_moe_sorting`

View File

@@ -209,7 +209,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -227,7 +227,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -283,7 +283,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_size = kernel::GetSmemSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
@@ -334,15 +334,15 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
} \
}
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
[&]() { \
using problem_ = \
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
[&]() { \
using problem_ = \
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)

View File

@@ -9,7 +9,7 @@ Unlike standard smoothquant op, the input scale is from different expert `[exper
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_smoothquant -j
```
This will result in an executable `build/bin/tile_example_moe_smoothquant`

View File

@@ -53,7 +53,7 @@ float moe_smoothquant_(const S& s, A a)
using Kernel = ck_tile::MoeSmoothquant<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -61,5 +61,5 @@ float moe_smoothquant_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -53,7 +53,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
const dim3 grids = f_kernel::GridSize(a);
constexpr dim3 blocks = f_kernel::BlockSize();
const dim3 blocks = f_kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
static int printed = 0;
@@ -66,5 +66,5 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -213,7 +213,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -231,7 +231,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -287,7 +287,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_size = kernel::GetSmemSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \

View File

@@ -7,7 +7,7 @@ This folder contains example for batched GEMM using ck_tile tile-programming imp
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_batched_gemm -j
```
This will result in an executable `build/bin/tile_example_batched_gemm`

View File

@@ -142,7 +142,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
DsLayout,
CLayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -156,8 +155,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -176,7 +175,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -148,7 +148,7 @@ All the necessary parameters are set, the tiling is computed, the GEMM pipeline
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_grouped_gemm -j
```

View File

@@ -29,10 +29,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
void* kargs_ptr,
bool splitk)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
@@ -44,7 +40,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -53,8 +48,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
float ave_time{0};
@@ -82,7 +75,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -92,9 +84,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
@@ -105,7 +97,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,

View File

@@ -7,7 +7,7 @@ This folder contains example for FLATMM using ck_tile tile-programming implement
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the flatmm calculation
make tile_example_flatmm_basic -j
```

View File

@@ -101,7 +101,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
DsLayout,
ELayout,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
@@ -119,8 +118,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -171,15 +170,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};

View File

@@ -40,9 +40,11 @@ template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2)
: (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4);
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
@@ -213,6 +215,16 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
}
else if(init_method == 3)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
}
else if(init_method == 4)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
}
else
{
a_host.SetZero();

View File

@@ -8,7 +8,7 @@ This folder contains example for Multiple D GEMM using ck_tile tile-programming
mkdir build && cd build
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
#The basic pipeline method on the gemm calculation
make tile_example_gemm_multi_d_fp16 -j
```

View File

@@ -146,7 +146,6 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
DsLayout,
CLayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -160,8 +159,8 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -176,7 +175,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -6,3 +6,6 @@ target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMP
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp)
target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,215 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t M_Tile = 64;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
// Implicit GEMM Traits
using CodegenShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType =
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using ConvEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
1,
true,
VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
CodegenPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
if(args.k_batch == 1)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
}
#include "run_grouped_convolution_bwd_data_example.inc"
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_data_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
}
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");
std::string wei_layout = arg_parser.get_str("wei_layout");
std::string out_layout = arg_parser.get_str("out_layout");
if(data_type == "fp16")
{
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation!");
}
}
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_data_example(argc, argv); }

View File

@@ -78,7 +78,6 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -98,8 +97,8 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -123,7 +122,7 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
float ave_time = ck_tile::launch_kernel_time_mask(
s,
Kernel::Preprocess(kargs, s),
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -77,7 +77,6 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -97,8 +96,8 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -120,7 +119,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -0,0 +1,186 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
int n_warmup,
int n_repeat)
{
float ave_time = grouped_conv_bwd_data<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = args.GetFlops();
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
int run_grouped_conv_bwd_data_example_with_layouts(
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using AccDataType = float;
std::vector<ck_tile::index_t> filter_spatial_lengths;
std::vector<ck_tile::index_t> image_spatial_lengths;
std::vector<ck_tile::index_t> strides;
std::vector<ck_tile::index_t> dilations;
std::vector<ck_tile::index_t> lpads;
std::vector<ck_tile::index_t> rpads;
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads,
arg_parser);
ck_tile::conv::ConvParam conv_param{num_dim_sp,
arg_parser.get_int("g"),
arg_parser.get_int("n"),
arg_parser.get_int("k"),
arg_parser.get_int("c"),
filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads};
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
const auto in_g_n_c_wis_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
if(init_method == 0)
{
ck_tile::FillUniformDistribution<WeiDataType>{-1.f, 1.f}(weight);
ck_tile::FillUniformDistribution<OutDataType>{-1.f, 1.f}(output);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
ck_tile::FillMonotonicSeq<OutDataType>{}(output);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(output);
}
else
{
weight.SetZero();
output.SetZero();
}
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
input_dev_buf.SetZero();
weight_dev_buf.ToDevice(weight.data());
output_dev_buf.ToDevice(output.data());
ck_tile::GroupedConvBwdDataHostArgs args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);
std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl;
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_data<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
input_dev_buf.FromDevice(input.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<InDataType> input_host_ref(in_g_n_c_wis_desc);
input_host_ref.SetZero();
ck_tile::reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
input_host_ref,
weight,
output,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_);
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const float max_accumulated_value =
*std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end());
const auto rtol_atol =
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
GemmK, kbatch, max_accumulated_value);
pass = ck_tile::check_err(input,
input_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
throw std::runtime_error("Unsupported gpu verification !!!");
}
return pass;
}

View File

@@ -167,17 +167,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// 4. Run the kernel
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(N, 1), // Input Stride
ck_tile::make_tuple(N, 1), // Output Stride
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(N, 1), // Input Stride
ck_tile::make_tuple(N, 1), // Output Stride
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;

View File

@@ -113,7 +113,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// Run the kernel
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,

View File

@@ -112,17 +112,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// 4. Run the kernel
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0, // Shared memory
op_lengths, // Logical dimensions for the operation (M, N)
input_strides, // Strides for input tensor(s)
output_strides, // Strides for output tensor (N, M)
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0, // Shared memory
op_lengths, // Logical dimensions for the operation (M, N)
input_strides, // Strides for input tensor(s)
output_strides, // Strides for output tensor (N, M)
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;

View File

@@ -99,17 +99,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// 4. Run the kernel
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(N, 1), // Input Stride
ck_tile::make_tuple(N, 1), // Output Stride
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(N, 1), // Input Stride
ck_tile::make_tuple(N, 1), // Output Stride
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;

View File

@@ -6,7 +6,7 @@ This folder contains example for batched Transpose using ck_tile tile-programmin
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# Make the transpose executable
make tile_example_batched_transpose -j
```

View File

@@ -74,8 +74,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
const dim3 grids = kernel::GridSize(a);
const dim3 blocks = kernel::BlockSize();
printf("Pipeline: %d\n", Config::kPipelineId);
printf("Grid: x=%u y=%u z=%u\n", grids.x, grids.y, grids.z);
@@ -96,8 +96,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
printf("Launching Kernel...\n");
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs));
printf("Kernel finished...\n");

View File

@@ -8,9 +8,8 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp)
target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp)
target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp)
target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()

View File

@@ -7,9 +7,10 @@ This folder contains example for Block Scale GEMM using ck_tile tile-programming
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# The aquant pipeline method on the gemm calculation
make tile_example_gemm_aquant_basic -j
make tile_example_gemm_bquant_basic -j
```
This will result in an executable `build/bin/tile_example_gemm_aquant_basic`

View File

@@ -8,11 +8,10 @@
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -21,29 +20,26 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -52,8 +48,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -68,7 +69,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
constexpr bool transposed_warp_gemm = true;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
@@ -82,6 +83,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transposed_warp_gemm,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
@@ -96,7 +98,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -111,8 +112,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
@@ -136,7 +137,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
@@ -187,13 +188,14 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
@@ -201,32 +203,18 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
float,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
float,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
@@ -235,4 +223,4 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigComputeV3>(argc, argv); }
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }

View File

@@ -8,11 +8,10 @@
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -21,8 +20,7 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
@@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -82,6 +85,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transposed_warp_gemm,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
@@ -96,7 +100,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -111,8 +114,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
@@ -136,7 +139,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
@@ -187,13 +190,14 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
@@ -201,7 +205,7 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
float,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
@@ -210,29 +214,18 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
float,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigPreshufle_AQ>(argc, argv); }
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigPreshuffleQuant>(argc, argv);
}

View File

@@ -0,0 +1,228 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmBQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::BQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
;
}
#include "run_gemm_bquant_example.inc"
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for B.");
}
return 0;
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8i4")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8i4")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }

View File

@@ -11,11 +11,9 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE 5
#define CK_TILE_PIPELINE_PREFILL 1
#define CK_TILE_PIPELINE_DECODE 2
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -83,200 +81,37 @@ struct GemmConfigBase
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
struct GemmConfigDecode : public GemmConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
};
template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 1;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 2;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -288,18 +123,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
};
template <typename PrecType>
struct GemmConfigPreshufle_AQ : public GemmConfigBase
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
@@ -314,9 +146,11 @@ struct GemmConfigPreshufle_AQ : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
static constexpr bool PreshuffleQuant = true;
};
template <typename ADataType_,
@@ -332,176 +166,6 @@ struct GemmQuantTypeConfig
using CDataType = CDataType_;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using QDataType = float;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
@@ -559,55 +223,6 @@ struct DataTypeTraits<ck_tile::int8_t>
static constexpr const char* name = "int8";
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;

View File

@@ -1,3 +1,4 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
@@ -31,7 +32,8 @@ auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -40,8 +42,7 @@ template <typename ADataType,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
@@ -73,7 +74,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
float ave_time = gemm_calc_aquant<ADataType,
float ave_time = gemm_calc_aquant<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
@@ -82,8 +84,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ALayout,
BLayout,
CLayout,
QuantGroupSize,
Preshuffle>(
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
@@ -206,7 +207,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(GemmConfig::Preshuffle)
if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
shuffle_aq(aq_m_aqk, GemmConfig::K_Tile / QuantGroupSize);
@@ -222,7 +223,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType,
invoke_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
@@ -231,22 +233,21 @@ int run_gemm_example_with_layouts(int argc,
AQLayout,
BLayout,
CLayout,
QuantGroupSize,
GemmConfig::Preshuffle>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
QuantGroupSize>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -0,0 +1,286 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <bit>
#include <random>
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename T>
auto shuffle_bq(const ck_tile::HostTensor<T>& t, int block_bq_k)
{
if(t.get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int n_ = t.get_lengths()[0];
int bqk_ = t.get_lengths()[1];
if(bqk_ % block_bq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a bqk of multiple times of block_bq_k.");
}
ck_tile::HostTensor<T> t_view({n_, bqk_ / block_bq_k, block_bq_k});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename BQDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename BQLayout,
typename DsLayout,
typename CLayout,
uint32_t QuantGroupSize,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& bq_bqk_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t BQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_BQ,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::BQuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.bq_ptr = bq_bqk_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK = BQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_BQ = stride_BQ;
float ave_time = gemm_calc_bquant<GemmConfig,
ADataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
ADataType, // computeDatatype
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K +
sizeof(BQDataType) * BQK * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideBQ =" << stride_BQ
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " A_Type = " << DataTypeTraits<ADataType>::name
<< " B_Type = " << DataTypeTraits<BDataType>::name
<< " BQ_Type = " << DataTypeTraits<BQDataType>::name
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename ALayout,
typename BLayout,
typename BQLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
const BQLayout bq_layout = BQLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using BQDataType = typename TypeConfig::QDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be aligned with QuantGroupSize");
}
ck_tile::index_t BQK = K / QuantGroupSize;
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_BQ = arg_parser.get_int("stride_q");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<BQDataType> bq_bqk_n(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(bq_bqk_n);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return 0;
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(bq_bqk_n);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
bq_bqk_n.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<GemmConfig,
ADataType,
BDataType,
BQDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
BQLayout,
ck_tile::tuple<>,
CLayout,
QuantGroupSize>(a_m_k_dev_buf,
b_k_n_dev_buf,
bq_bqk_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
BQK,
stride_A,
stride_B,
stride_BQ,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
}
return pass;
}

View File

@@ -12,7 +12,7 @@ This experimental kernel is intended for novice CK developers. It introduces the
mkdir build && cd build
# you can replace <arch> with the appropriate architecture
# (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# Make the copy kernel executable
make tile_example_copy -j
```

View File

@@ -77,10 +77,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
// we intentionally do not use pipeline for this example and let the kernel be composite of
// Problem and Policy
constexpr ck_tile::index_t kBlockSize = Shape::BlockSize;
auto blockSize = Kernel::BlockSize();
// Print configuration information
std::cout << "block size (number of threads per block) " << kBlockSize << std::endl;
std::cout << "block size (number of threads per block) " << blockSize << std::endl;
std::cout << "wave size (number of threads per wave) " << ck_tile::get_warp_size() << std::endl;
std::cout << "block waves (number of waves per block) " << BlockWaves::at(ck_tile::number<0>{})
<< " " << BlockWaves::at(ck_tile::number<1>{}) << std::endl;
@@ -99,16 +99,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< ")" << std::endl;
// Launch kernel
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, warmup, repeat, 1},
ck_tile::make_kernel<kBlockSize, 1>(Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
m,
n));
float ave_time =
launch_kernel(ck_tile::stream_config{nullptr, true, warmup, repeat, 1},
ck_tile::make_kernel<1>(Kernel{},
kGridSize,
blockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
m,
n));
// Calculate and print performance metrics
std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m * n;

View File

@@ -27,8 +27,9 @@ struct TileCopyShape
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
// Wave tile dimensions
static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{});
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t Wave_Tile_N = WaveTile::at(number<1>{});
static constexpr index_t Wave_Tile_M = ThreadTile_M * ThreadTile_N * WaveSize / Wave_Tile_N;
// Block tile dimensions
static constexpr index_t Block_Tile_M = BlockTile::at(number<0>{});
@@ -45,7 +46,6 @@ struct TileCopyShape
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
// Hardware configuration
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t BlockSize = Waves_Per_Block_M * Waves_Per_Block_N * WaveSize;
// Configuration validation
@@ -60,8 +60,10 @@ struct TileCopyShape
"Invalid wave configuration for N dimension");
// Ensure wave tile dimensions align with wave size
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize,
"(Wave_Tile_M/ThreadTile_M) * (Wave_Tile_N/ThreadTile_N) != WaveSize");
#endif
};
/**
@@ -200,6 +202,19 @@ struct ElementWiseTileCopyKernel
using XDataType = typename Problem::XDataType;
using Policy = ck_tile::remove_cvref_t<Policy_>;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
CK_TILE_HOST static auto BlockSize()
{
if(ck_tile::is_wave32())
{
return kBlockSize / 2;
}
else
{
return kBlockSize;
}
}
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
{
using S = typename Problem::BlockShape;

View File

@@ -50,10 +50,11 @@
#endif
// define general macros for various architectures
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || \
defined(__gfx9_4_generic__)
#define __gfx9__
#endif
#if defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
#define __gfx94__
#endif
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -52,10 +52,27 @@ inline std::string get_device_name()
}
}
inline bool is_gfx12_supported()
{
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
}
inline bool is_gfx11_supported()
{
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" ||
ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" ||
ck::get_device_name() == "gfx1152";
}
inline bool is_xdl_supported()
{
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
|| is_gfx12_supported() || is_gfx11_supported()
#endif
;
}
inline bool is_lds_direct_load_supported()
@@ -67,7 +84,8 @@ inline bool is_lds_direct_load_supported()
inline bool is_bf16_atomic_supported()
{
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
is_gfx12_supported();
}
inline bool is_gfx101_supported()
@@ -83,18 +101,5 @@ inline bool is_gfx103_supported()
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
}
inline bool is_gfx11_supported()
{
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" ||
ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" ||
ck::get_device_name() == "gfx1152";
}
inline bool is_gfx12_supported()
{
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
}
} // namespace ck
#endif

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdexcept>
#include <string>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
namespace ck {
namespace utils {
template <typename Layout>
inline void
validate_gemm_stride(int M, int N, int stride, const std::string& stride_name = "Stride")
{
if(ck::is_same_v<Layout, ck::tensor_layout::gemm::ColumnMajor>)
{
if(stride < M)
{
throw std::runtime_error(
"Error: For ColumnMajor layout, " + stride_name + " (" + std::to_string(stride) +
") must be greater than or equal to dim (" + std::to_string(M) + ")");
}
}
else // RowMajor
{
if(stride < N)
{
throw std::runtime_error(
"Error: For RowMajor layout, " + stride_name + " (" + std::to_string(stride) +
") must be greater than or equal to dim (" + std::to_string(N) + ")");
}
}
}
// Convenience functions for common GEMM patterns
template <typename ALayout, typename BLayout, typename CLayout>
inline void validate_gemm_strides_abc(int M, int N, int K, int StrideA, int StrideB, int StrideC)
{
validate_gemm_stride<ALayout>(M, K, StrideA, "StrideA");
validate_gemm_stride<BLayout>(K, N, StrideB, "StrideB");
validate_gemm_stride<CLayout>(M, N, StrideC, "StrideC");
}
} // namespace utils
} // namespace ck

View File

@@ -41,7 +41,9 @@ struct BlockwiseGemmXdlops_pipeline_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
static constexpr index_t WaveSize = 64;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
@@ -74,9 +76,6 @@ struct BlockwiseGemmXdlops_pipeline_base
return 1;
}();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
using HotLoopInstList =
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
@@ -219,6 +218,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
@@ -227,6 +227,7 @@ struct BlockwiseGemmXdlops_pipeline_base
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
#endif
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -625,13 +626,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::WaveSize;
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -141,9 +141,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

Some files were not shown because too many files have changed in this diff Show More