Merge branch 'develop' into hstu_attention_mi350_fwd_bwd and change in using ck_tile::make_kernel

This commit is contained in:
Qianfeng Zhang
2025-09-01 07:13:16 +00:00
266 changed files with 9617 additions and 3260 deletions

112
.github/scripts/therock_configure_ci.py vendored Normal file
View File

@@ -0,0 +1,112 @@
import fnmatch
import json
import os
from pathlib import Path
import subprocess
import sys
from typing import Iterable, Optional, Mapping
def gha_set_output(vars: Mapping[str, str | Path]):
"""Sets values in a step's output parameters.
This appends to the file located at the $GITHUB_OUTPUT environment variable.
See
* https://docs.github.com/en/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter
* https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs
"""
print(f"Setting github output:\n{vars}")
step_output_file = os.getenv("GITHUB_OUTPUT")
if not step_output_file:
print(" Warning: GITHUB_OUTPUT env var not set, can't set github outputs")
return
with open(step_output_file, "a") as f:
f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items())
def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
"""Returns the paths of modified files relative to the base reference."""
try:
return subprocess.run(
["git", "diff", "--name-only", base_ref],
stdout=subprocess.PIPE,
check=True,
text=True,
timeout=60,
).stdout.splitlines()
except TimeoutError:
print(
"Computing modified files timed out. Not using PR diff to determine"
" jobs to run.",
file=sys.stderr,
)
return None
# Paths matching any of these patterns are considered to have no influence over
# build or test workflows so any related jobs can be skipped if all paths
# modified by a commit/PR match a pattern in this list.
SKIPPABLE_PATH_PATTERNS = [
"docs/*",
"*.gitignore",
"*.md",
"*.pre-commit-config.*",
"*LICENSE",
'Jenkinsfile',
'.github/ISSUE_TEMPLATE/*',
'.github/CODEOWNERS',
'.github/*.md',
'.github/dependabot.yml',
]
def is_path_skippable(path: str) -> bool:
"""Determines if a given relative path to a file matches any skippable patterns."""
return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS)
def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool:
"""Returns true if at least one path is not in the skippable set."""
if paths is None:
return False
return any(not is_path_skippable(p) for p in paths)
def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
"""Returns true if CI workflows should run given a list of modified paths."""
if paths is None:
print("No files were modified, skipping TheRock CI jobs")
return False
paths_set = set(paths)
github_workflows_paths = set(
[p for p in paths if p.startswith(".github/workflows")]
)
other_paths = paths_set - github_workflows_paths
contains_other_non_skippable_files = check_for_non_skippable_path(other_paths)
print("should_ci_run_given_modified_paths findings:")
print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}")
if contains_other_non_skippable_files:
print("Enabling TheRock CI jobs since a non-skippable path was modified")
return True
else:
print(
"Only unrelated and/or skippable paths were modified, skipping TheRock CI jobs"
)
return False
def main(args):
base_ref = args.get("base_ref")
modified_paths = get_modified_paths(base_ref)
print("modified_paths (max 200):", modified_paths[:200])
enable_jobs = should_ci_run_given_modified_paths(modified_paths)
output = {
'enable_therock_ci': json.dumps(enable_jobs)
}
gha_set_output(output)
if __name__ == "__main__":
args = {}
args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1")
main(args)

View File

@@ -21,9 +21,11 @@ jobs:
id-token: write
container:
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292
options: -v /runner/config:/home/awsconfig/
env:
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
TEATIME_FORCE_INTERACTIVE: 0
AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini
steps:
- name: Checkout composable_kernel repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -83,9 +85,9 @@ jobs:
echo "----------"
du -h -d 1 TheRock/build/artifacts
- name: Configure AWS Credentials
if: always()
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
- name: Configure AWS Credentials for non-forked repos
if: ${{ always() && !github.event.pull_request.head.repo.fork }}
uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1
with:
aws-region: us-east-2
role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external

View File

@@ -5,6 +5,15 @@ on:
branches:
- develop
workflow_dispatch:
pull_request:
types:
- opened
- synchronize
branches:
- mainline
- release/*
- release-staging/*
- develop
permissions:
contents: read
@@ -18,8 +27,29 @@ concurrency:
cancel-in-progress: true
jobs:
setup:
runs-on: ubuntu-24.04
env:
# The commit being checked out is the merge commit for a PR. Its first
# parent will be the tip of the base branch.
BASE_REF: HEAD^
outputs:
enable_therock_ci: ${{ steps.configure.outputs.enable_therock_ci }}
steps:
- name: "Checking out repository"
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
# We need the parent commit to do a diff
fetch-depth: 2
- name: "Configuring CI options"
id: configure
run: python .github/scripts/therock_configure_ci.py
therock-ci-linux:
name: TheRock CI Linux
needs: setup
if: ${{ needs.setup.outputs.enable_therock_ci == 'true' }}
permissions:
contents: read
id-token: write
@@ -34,6 +64,7 @@ jobs:
name: TheRock CI Summary
if: always()
needs:
- setup
- therock-ci-linux
runs-on: ubuntu-24.04
steps:

View File

@@ -68,6 +68,7 @@ jobs:
VENV_DIR: ${{ env.VENV_DIR }}
FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }}
PLATFORM: ${{ inputs.platform }}
IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }}
- name: Test
timeout-minutes: ${{ matrix.components.timeout_minutes }}

View File

@@ -2,7 +2,7 @@
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
## Composable Kernel 1.1.0 for ROCm 7.0.0
## Composable Kernel 1.2.0 for ROCm 7.0.0
### Added
@@ -27,6 +27,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added int8 support for CK_TILE GEMM.
* Added support for elementwise kernel.
* Added benchmarking support for tile engine GEMM Multi D.
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
### Optimized
@@ -48,6 +49,7 @@ None
* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced.
* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced.
* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced.
* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594)
### Known issues

View File

@@ -16,12 +16,21 @@ else()
"Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.")
endif()
# Allow user to specify the C++ standard.
# We must support C++17 builds until downstream users are migrated to C++20, but we default to C++20.
set(CK_CXX_STANDARD "20" CACHE STRING "C++ standard to use (e.g. 17 or 20)")
set(valid_cxx_standards 17 20)
set_property(CACHE CK_CXX_STANDARD PROPERTY STRINGS ${valid_cxx_standards})
if(NOT CK_CXX_STANDARD IN_LIST valid_cxx_standards)
message(FATAL_ERROR "CK_CXX_STANDARD must be one of ${valid_cxx_standards}")
endif()
# Default installation path
if(NOT WIN32)
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
endif()
set(version 1.1.0)
set(version 1.2.0)
# Check support for CUDA/HIP in Cmake
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
include(CTest)
@@ -221,11 +230,20 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9
add_definitions(-DCK_USE_GFX94)
set(CK_USE_GFX94 "ON")
endif()
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
set(CK_TILE_USE_WMMA 0)
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message(STATUS "Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
set(CK_TILE_USE_WMMA 1)
endif()
# define the macro with the current value (0 or 1)
add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
add_definitions(-DCK_USE_WMMA_FP8)
@@ -324,32 +342,19 @@ if(USE_BITINT_EXTENSION_INT4)
message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
endif()
if(USE_OPT_GFX11)
add_compile_options(-mcumode)
add_compile_options(-mno-wavefrontsize64)
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
endif()
if(ENABLE_ASM_DUMP)
add_compile_options(--save-temps)
add_compile_options(-Wno-gnu-line-marker)
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
endif()
if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12"))
add_compile_options(-mno-wavefrontsize64)
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}")
endif()
## Threads
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
link_libraries(Threads::Threads)
## C++
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD ${CK_CXX_STANDARD})
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")

23
Dockerfile.pytorch Normal file
View File

@@ -0,0 +1,23 @@
ARG BASE_DOCKER="rocm/pytorch-nightly:latest"
FROM $BASE_DOCKER
ARG CK_PYTORCH_BRANCH="develop"
RUN groupadd -g 109 render && \
usermod -u 1001 jenkins && \
groupmod -g 1001 jenkins && \
cd /tmp/pytorch && \
rm -rf build && \
cd /tmp/pytorch/third_party && \
rm -rf composable_kernel && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
cd /tmp/pytorch/third_party/aiter/3rdparty && \
rm -rf composable_kernel && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
cd /tmp/pytorch/third_party/fbgemm/external && \
rm -rf composable_kernel && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
cd /tmp/pytorch/third_party/flash-attention/csrc && \
rm -rf composable_kernel && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
chown -R jenkins:jenkins /tmp/pytorch && \
chmod -R a+rwx /tmp/pytorch && \
sudo usermod -aG irc jenkins

237
Jenkinsfile vendored
View File

@@ -192,12 +192,16 @@ def buildDocker(install_prefix){
image_name = "rocm/composable_kernel:ck_aiter"
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . "
}
else{
else if(params.RUN_PYTORCH_TESTS){
image_name = "rocm/composable_kernel:ck_pytorch"
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . "
}
else{
dockerArgs = dockerArgs + " -f Dockerfile . "
}
echo "Build Args: ${dockerArgs}"
try{
if(params.BUILD_DOCKER || params.RUN_AITER_TESTS){
if(params.BUILD_DOCKER || params.RUN_AITER_TESTS || params.RUN_PYTORCH_TESTS){
//force building the new docker if that parameter is true
echo "Building image: ${image_name}"
retimage = docker.build("${image_name}", dockerArgs)
@@ -400,8 +404,9 @@ def cmake_build(Map conf=[:]){
echo "Build packages"
sh 'ninja -j64 package'
archiveArtifacts artifacts: 'composablekernel-dev*.deb'
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb'
stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages"
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb'
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64.deb'
stash includes: "composablekernel-**.deb", name: "packages"
}
}
else{
@@ -571,50 +576,66 @@ def Build_CK(Map conf=[:]){
python3 -m pytest python/test/test_gen_instances.py
"""
}
dir("build"){
if (params.RUN_FULL_QA && arch == 2 ){
// build deb packages
echo "Build packages"
sh 'ninja package'
archiveArtifacts artifacts: 'composablekernel*.deb'
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb'
sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb'
sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb'
sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb'
stash includes: "composablekernel-**.deb", name: "packages"
}
}
// run performance tests, stash the logs, results will be processed on the master node
dir("script"){
if (params.RUN_PERFORMANCE_TESTS){
if (params.RUN_FULL_QA && arch == 1){
// run full tests on gfx90a
echo "Run full performance tests"
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
archiveArtifacts "perf_gemm.log"
archiveArtifacts "perf_resnet50_N256.log"
archiveArtifacts "perf_resnet50_N4.log"
archiveArtifacts "perf_batched_gemm.log"
archiveArtifacts "perf_grouped_gemm.log"
archiveArtifacts "perf_grouped_conv_fwd.log"
archiveArtifacts "perf_grouped_conv_bwd_data.log"
archiveArtifacts "perf_grouped_conv_bwd_weight.log"
archiveArtifacts "perf_gemm_bilinear.log"
archiveArtifacts "perf_reduction.log"
archiveArtifacts "perf_splitK_gemm.log"
archiveArtifacts "perf_onnx_gemm.log"
archiveArtifacts "perf_mixed_gemm.log"
stash includes: "perf_**.log", name: "perf_log"
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a"
archiveArtifacts "perf_gemm_gfx90a.log"
archiveArtifacts "perf_resnet50_N256_gfx90a.log"
archiveArtifacts "perf_resnet50_N4_gfx90a.log"
archiveArtifacts "perf_batched_gemm_gfx90a.log"
archiveArtifacts "perf_grouped_gemm_gfx90a.log"
archiveArtifacts "perf_grouped_conv_fwd_gfx90a.log"
archiveArtifacts "perf_grouped_conv_bwd_data_gfx90a.log"
archiveArtifacts "perf_grouped_conv_bwd_weight_gfx90a.log"
archiveArtifacts "perf_gemm_bilinear_gfx90a.log"
archiveArtifacts "perf_reduction_gfx90a.log"
archiveArtifacts "perf_splitK_gemm_gfx90a.log"
archiveArtifacts "perf_onnx_gemm_gfx90a.log"
archiveArtifacts "perf_mixed_gemm_gfx90a.log"
stash includes: "perf_**.log", name: "perf_log_gfx90a"
}
if (params.RUN_FULL_QA && arch == 2){
// run full tests on gfx942
echo "Run full performance tests"
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942"
archiveArtifacts "perf_gemm_gfx942.log"
archiveArtifacts "perf_resnet50_N256_gfx942.log"
archiveArtifacts "perf_resnet50_N4_gfx942.log"
archiveArtifacts "perf_batched_gemm_gfx942.log"
archiveArtifacts "perf_grouped_gemm_gfx942.log"
archiveArtifacts "perf_grouped_conv_fwd_gfx942.log"
archiveArtifacts "perf_grouped_conv_bwd_data_gfx942.log"
archiveArtifacts "perf_grouped_conv_bwd_weight_gfx942.log"
archiveArtifacts "perf_gemm_bilinear_gfx942.log"
archiveArtifacts "perf_reduction_gfx942.log"
archiveArtifacts "perf_splitK_gemm_gfx942.log"
archiveArtifacts "perf_onnx_gemm_gfx942.log"
archiveArtifacts "perf_mixed_gemm_gfx942.log"
stash includes: "perf_**.log", name: "perf_log_gfx942"
}
else if ( arch == 1 ){
// run standard tests on gfx90a
echo "Run performance tests"
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
archiveArtifacts "perf_gemm.log"
archiveArtifacts "perf_onnx_gemm.log"
archiveArtifacts "perf_resnet50_N256.log"
archiveArtifacts "perf_resnet50_N4.log"
stash includes: "perf_**.log", name: "perf_log"
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a"
archiveArtifacts "perf_gemm_gfx90a.log"
archiveArtifacts "perf_onnx_gemm_gfx90a.log"
archiveArtifacts "perf_resnet50_N256_gfx90a.log"
archiveArtifacts "perf_resnet50_N4_gfx90a.log"
stash includes: "perf_**.log", name: "perf_log_gfx90a"
}
else if ( arch == 2 ){
// run standard tests on gfx942
echo "Run performance tests"
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942"
archiveArtifacts "perf_gemm_gfx942.log"
archiveArtifacts "perf_onnx_gemm_gfx942.log"
archiveArtifacts "perf_resnet50_N256_gfx942.log"
archiveArtifacts "perf_resnet50_N4_gfx942.log"
stash includes: "perf_**.log", name: "perf_log_gfx942"
}
// disable performance tests on gfx1030 for now.
//else if ( arch == 3){
@@ -732,29 +753,64 @@ def process_results(Map conf=[:]){
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
unstash "perf_fmha_log_gfx942"
}
catch(Exception err){
echo "could not locate the FMHA performance logs for gfx942: ${err.getMessage()}."
}
try{
unstash "perf_fmha_log_gfx90a"
}
catch(Exception err){
echo "could not locate the FMHA performance logs: ${err.getMessage()}."
echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}."
}
}
if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){
if (params.BUILD_INSTANCES_ONLY){
// unstash deb packages
unstash "packages"
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
}
else{
// unstash perf files to master
unstash "perf_log"
try{
unstash "perf_log_gfx90a"
}
catch(Exception err){
echo "could not locate the gfx90a performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx942"
}
catch(Exception err){
echo "could not locate the gfx942 performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx950"
}
catch(Exception err){
echo "could not locate the gfx950 performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx908"
}
catch(Exception err){
echo "could not locate the gfx908 performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx11"
}
catch(Exception err){
echo "could not locate the gfx11 performance logs: ${err.getMessage()}."
}
try{
unstash "perf_log_gfx12"
}
catch(Exception err){
echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}."
echo "could not locate the gfx12 performance logs: ${err.getMessage()}."
}
sh "./process_perf_data.sh"
}
// process the logs
sh "./process_perf_data.sh"
}
}
catch(e){
@@ -819,13 +875,64 @@ def run_aiter_tests(Map conf=[:]){
}
}
def run_pytorch_tests(Map conf=[:]){
show_node_info()
env.HSA_ENABLE_SDMA=0
checkout scm
//use the latest pytorch-nightly image
def image = "rocm/composable_kernel:ck_pytorch"
def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins"
def variant = env.STAGE_NAME
def retimage
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
echo "Docker flags: ${dockerOpts}"
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try
{
echo "Pulling image: ${image}"
retimage = docker.image("${image}")
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
retimage.pull()
}
}
catch(Exception ex)
{
error "Unable to locate image: ${image}"
}
}
withDockerContainer(image: image, args: dockerOpts) {
timeout(time: 45, unit: 'MINUTES'){
try{
sh "rocminfo"
sh "python3 --version"
sh "python3 /tmp/pytorch/tools/amd_build/build_amd.py"
sh "USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python /tmp/pytorch/setup.py develop"
}
catch(e){
echo "Throwing error exception while building Pytorch"
echo 'Exception occurred: ' + e.toString()
throw e
}
finally{
echo "Finished building Pytorch"
}
}
}
}
//launch develop branch daily jobs
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : ""
0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false
0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX10=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false''' : ""
pipeline {
agent none
@@ -960,6 +1067,14 @@ pipeline {
name: "RUN_ALL_UNIT_TESTS",
defaultValue: false,
description: "Run all unit tests (default: OFF)")
booleanParam(
name: "RUN_PYTORCH_TESTS",
defaultValue: false,
description: "Try building PYTORCH with latest CK develop branch (default: OFF)")
string(
name: 'ck_pytorch_branch',
defaultValue: 'develop',
description: 'Specify which branch of CK to test with Pytorch (default: develop)')
booleanParam(
name: "RUN_AITER_TESTS",
defaultValue: false,
@@ -1051,6 +1166,24 @@ pipeline {
}
}
}
}
stage("Run Pytorch Tests")
{
parallel
{
stage("Run Pytorch Tests on gfx942")
{
when {
beforeAgent true
expression { params.RUN_PYTORCH_TESTS.toBoolean() }
}
agent{ label rocmnode("gfx942")}
steps{
run_pytorch_tests()
cleanWs()
}
}
}
}
stage("Run AITER Tests")
{
@@ -1107,11 +1240,16 @@ pipeline {
agent{ label rocmnode("gfx90a")}
environment{
setup_args = "NO_CK_BUILD"
execute_args = """ cd test_data && \
./generate_test_dataset.sh && \
cd ../script && \
execute_args = """ cd ../build && \
../script/cmake-ck-dev.sh ../ gfx90a && \
make -j64 test_grouped_convnd_fwd_dataset_xdl && \
cd ../test_data && \
# Dataset generation modes:
# - small: ~60 test cases (minimal, quick testing - 3 models, 2 batch sizes, 2 image sizes)
# - half: ~300 test cases (moderate coverage - 16 models, 3 batch sizes, 5 image sizes), ~ 17 hours testing time
# - full: ~600 test cases (comprehensive - 16 models, 5 batch sizes, 9 image sizes), ~ 40 hours testing time
./generate_test_dataset.sh half && \
cd ../build && \
./bin/test_grouped_convnd_fwd_dataset_xdl"""
}
steps{
@@ -1306,6 +1444,7 @@ pipeline {
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3"
setup_args = """ -DGPU_TARGETS="gfx942" \
-DCMAKE_CXX_FLAGS=" -O3 " \
-DCK_CXX_STANDARD="17" \
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
execute_args = " "
}
@@ -1440,7 +1579,7 @@ pipeline {
-D CMAKE_BUILD_TYPE=Release \
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0")
}
cleanWs()
}
@@ -1517,7 +1656,7 @@ pipeline {
stage("Process results"){
when {
beforeAgent true
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() }
}
agent { label 'mici' }
steps{

View File

@@ -19,7 +19,6 @@ Getting started
build the library. You can also find some of this information in the
`README file <https://github.com/ROCm/composable_kernel/blob/develop/README.md>`_
on the project's GitHub page.
#. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_ provides a deeper understanding of the CK library and showcases its performance capabilities.
<https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_
from the AMD Community portal. It offers a deeper understanding of the library's objectives and showcases its performance capabilities.
#. **General information:** For broader information about AMD products, consider exploring the

View File

@@ -1,7 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/utility/validation_common.hpp"
template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
@@ -53,6 +54,17 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
try
{
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
M, N, K, StrideA, StrideB, StrideC);
}
catch(const std::runtime_error& e)
{
std::cerr << "Error: " << e.what() << std::endl;
return false;
}
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -7,7 +7,7 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_fmha_fwd -j
```
This will result in an executable `build/bin/tile_example_fmha_fwd`

View File

@@ -110,9 +110,9 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -136,10 +136,10 @@ float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
@@ -148,9 +148,9 @@ void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_co
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
@@ -425,10 +425,10 @@ float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
@@ -436,9 +436,9 @@ void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_co
{{
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}
@@ -530,10 +530,10 @@ float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_confi
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
template <>
@@ -542,9 +542,9 @@ void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::strea
{{
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
ck_tile::stream_config{{s.stream_id_}});
}}

View File

@@ -110,9 +110,9 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
@@ -385,7 +385,7 @@ class FmhaFwdApiPool:
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][(hdim, hdim_v)]
traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'

View File

@@ -60,9 +60,9 @@ float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fw
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -108,9 +108,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
}}
@@ -208,9 +208,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
{{
using k_ = fmha_kernel;
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
}}
}};
}}

View File

@@ -109,9 +109,9 @@ float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -809,20 +809,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, "
"bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n",
fmha_traits.hdim_q,
fmha_traits.hdim_v,
fmha_traits.data_type.c_str(),
fmha_traits.is_group_mode,
static_cast<int>(fmha_traits.mask_type),
static_cast<int>(fmha_traits.bias_type),
fmha_traits.has_dbias,
fmha_traits.has_dropout,
fmha_traits.is_store_randval,
fmha_traits.is_deterministic);
fflush(stdout);
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
dq_buf.FromDevice(dq_host.data());

View File

@@ -9,6 +9,8 @@
# host name : $hostname
# gpu architecture: e.g., gfx90a, or gfx942, etc.
set -euo pipefail
#get the command line arguments:
export env_type=$1
echo 'Environment type: ' $env_type

View File

@@ -1,5 +1,7 @@
#!/bin/sh
#!/bin/bash
# TODO: run this script from CK root or build directory
set -euo pipefail
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
KNAME=1
@@ -17,12 +19,12 @@ for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done

View File

@@ -1,5 +1,7 @@
#!/bin/bash
# TODO: run this script from CK root or build directory
set -euo pipefail
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
KNAME=1
@@ -51,19 +53,18 @@ run_fp16_bf16_tests() {
for cache_batch_idx in $CACHE_BATCH_IDX ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ;
}
run_fp8_tests() {

View File

@@ -42,7 +42,7 @@ return hidden_states, per_token_scale
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_layernorm2d_fwd -j
```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`

View File

@@ -235,7 +235,7 @@ float layernorm2d_fwd_(const S& s, A a)
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -243,7 +243,7 @@ float layernorm2d_fwd_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -7,7 +7,7 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation

View File

@@ -26,6 +26,15 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
#if CK_TILE_USE_WMMA
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
@@ -33,6 +42,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -65,7 +75,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -81,8 +90,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -100,10 +109,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
<< std::endl;
}
float ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
@@ -208,7 +208,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -232,7 +231,7 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -279,15 +278,13 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
@@ -373,7 +370,7 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
float ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,

4
example/ck_tile/03_gemm/gemm_utils.hpp Executable file → Normal file
View File

@@ -172,6 +172,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
#if CK_TILE_USE_WMMA
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
@@ -192,6 +193,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
#endif
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
@@ -484,7 +486,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -126,7 +125,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -172,15 +171,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};

View File

@@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -127,7 +126,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -173,15 +172,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
@@ -338,7 +335,11 @@ int main(int argc, char* argv[])
try
{
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -7,7 +7,7 @@ This folder contains example for Image to Column using ck_tile tile-programming
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j
```
This will result in an executable `build/bin/tile_example_img2col`

View File

@@ -55,13 +55,12 @@ float image_to_column(const image_to_column_traits& traits,
args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1],
args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C,
args.G);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 2;
float ave_time = ck_tile::launch_kernel(
stream_conf,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
stream_conf, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}

View File

@@ -94,18 +94,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
throw std::runtime_error("Wrong! Arguments not supported!\n");
}
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
input_shape,
input_strides,
kept_dim,
reduce_dims));
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
input_shape,
input_strides,
kept_dim,
reduce_dims));
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;

View File

@@ -15,7 +15,7 @@ args:
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_permute -j
```
This will result in an executable `build/bin/tile_example_permute`

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -115,11 +115,12 @@ struct matrix_core_swizzle_kernel
__host__ void operator()(const ck_tile::stream_config& s) const
{
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
ck_tile::kentry<1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
}
struct kernel
{
static constexpr int kBlockSize = BLOCK_SIZE;
__device__ static constexpr auto get_src_dist()
{
using namespace ck_tile;

View File

@@ -53,11 +53,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
@@ -69,11 +69,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
@@ -85,11 +85,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(a);
const dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}

View File

@@ -6,7 +6,7 @@ This folder contains example for topk-softmax kernel using ck_tile tile-programm
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_topk_softmax -j
```
This will result in an executable `build/bin/tile_example_topk_softmax`

View File

@@ -13,11 +13,11 @@
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
constexpr dim3 blocks = kernel::BlockSize(); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs)); \
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;

View File

@@ -6,7 +6,7 @@ This folder contains example for Rmsnorm2D forward using ck_tile tile-programmin
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_rmsnorm2d_fwd -j
```
This will result in an executable `build/bin/tile_rmsnorm2d_fwd`

View File

@@ -138,12 +138,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;

View File

@@ -249,7 +249,7 @@ float rmsnorm2d_fwd_(const S& s, A a)
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Epilogue>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -257,7 +257,7 @@ float rmsnorm2d_fwd_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
}}
"""

View File

@@ -6,7 +6,7 @@ This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_add_rmsnorm2d_rdquant_fwd -j
```
This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd`

View File

@@ -136,12 +136,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;

View File

@@ -58,7 +58,7 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -66,5 +66,5 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -6,7 +6,7 @@ This folder contains example for smoothquant using ck_tile tile-programming impl
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_smoothquant -j
```
This will result in an executable `build/bin/tile_smoothquant`

View File

@@ -126,12 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 1, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;

View File

@@ -50,7 +50,7 @@ float smoothquant_(const S& s, A a)
using Kernel = ck_tile::Smoothquant<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -58,5 +58,5 @@ float smoothquant_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -6,7 +6,7 @@ This folder contains example for moe-sorting kernel using ck_tile tile-programmi
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_sorting -j
```
This will result in an executable `build/bin/tile_example_moe_sorting`

View File

@@ -209,7 +209,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -227,7 +227,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -283,7 +283,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_size = kernel::GetSmemSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
@@ -334,15 +334,15 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
} \
}
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
[&]() { \
using problem_ = \
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
[&]() { \
using problem_ = \
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)

View File

@@ -9,7 +9,7 @@ Unlike standard smoothquant op, the input scale is from different expert `[exper
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_smoothquant -j
```
This will result in an executable `build/bin/tile_example_moe_smoothquant`

View File

@@ -53,7 +53,7 @@ float moe_smoothquant_(const S& s, A a)
using Kernel = ck_tile::MoeSmoothquant<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
@@ -61,5 +61,5 @@ float moe_smoothquant_(const S& s, A a)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -53,7 +53,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
const dim3 grids = f_kernel::GridSize(a);
constexpr dim3 blocks = f_kernel::BlockSize();
const dim3 blocks = f_kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
static int printed = 0;
@@ -66,5 +66,5 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -213,7 +213,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -231,7 +231,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -287,7 +287,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_size = kernel::GetSmemSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \

View File

@@ -7,7 +7,7 @@ This folder contains example for batched GEMM using ck_tile tile-programming imp
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_batched_gemm -j
```
This will result in an executable `build/bin/tile_example_batched_gemm`

View File

@@ -142,7 +142,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
DsLayout,
CLayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -156,8 +155,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -176,7 +175,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -148,7 +148,7 @@ All the necessary parameters are set, the tiling is computed, the GEMM pipeline
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_grouped_gemm -j
```

View File

@@ -29,10 +29,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
void* kargs_ptr,
bool splitk)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
@@ -44,7 +40,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -53,8 +48,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
float ave_time{0};
@@ -82,7 +75,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -92,9 +84,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
@@ -105,7 +97,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,

View File

@@ -7,7 +7,7 @@ This folder contains example for FLATMM using ck_tile tile-programming implement
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the flatmm calculation
make tile_example_flatmm_basic -j
```

View File

@@ -101,7 +101,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
DsLayout,
ELayout,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
@@ -119,8 +118,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -171,15 +170,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};

View File

@@ -40,9 +40,11 @@ template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2)
: (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4);
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
@@ -213,6 +215,16 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
}
else if(init_method == 3)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
}
else if(init_method == 4)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
}
else
{
a_host.SetZero();

View File

@@ -135,9 +135,9 @@ struct batched_forward_causal_local_bias_dropout_dispatch
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockSize.x, kBlockPerCu>(
HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
(void)ck_tile::launch_kernel(
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
};
};

View File

@@ -126,9 +126,9 @@ struct jagged_forward_causal_local_bias_dropout_dispatch
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
(void)ck_tile::launch_kernel(ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockSize.x, kBlockPerCu>(
HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
(void)ck_tile::launch_kernel(
ck_tile::stream_config{stream, false},
ck_tile::make_kernel<kBlockPerCu>(HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
};
};

View File

@@ -8,7 +8,7 @@ This folder contains example for Multiple D GEMM using ck_tile tile-programming
mkdir build && cd build
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
#The basic pipeline method on the gemm calculation
make tile_example_gemm_multi_d_fp16 -j
```

View File

@@ -146,7 +146,6 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
DsLayout,
CLayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -160,8 +159,8 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -176,7 +175,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -6,3 +6,6 @@ target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMP
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp)
target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,215 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t M_Tile = 64;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
// Implicit GEMM Traits
using CodegenShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType =
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using ConvEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
1,
true,
VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
TilePartitioner,
CodegenPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
if(args.k_batch == 1)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
}
#include "run_grouped_convolution_bwd_data_example.inc"
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
int run_grouped_conv_bwd_data_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
}
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
{
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");
std::string wei_layout = arg_parser.get_str("wei_layout");
std::string out_layout = arg_parser.get_str("out_layout");
if(data_type == "fp16")
{
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation!");
}
}
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_data_example(argc, argv); }

View File

@@ -78,7 +78,6 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -98,8 +97,8 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -123,7 +122,7 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
float ave_time = ck_tile::launch_kernel_time_mask(
s,
Kernel::Preprocess(kargs, s),
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -77,7 +77,6 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -97,8 +96,8 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -120,7 +119,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};

View File

@@ -0,0 +1,186 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
int n_warmup,
int n_repeat)
{
float ave_time = grouped_conv_bwd_data<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = args.GetFlops();
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
int run_grouped_conv_bwd_data_example_with_layouts(
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using AccDataType = float;
std::vector<ck_tile::index_t> filter_spatial_lengths;
std::vector<ck_tile::index_t> image_spatial_lengths;
std::vector<ck_tile::index_t> strides;
std::vector<ck_tile::index_t> dilations;
std::vector<ck_tile::index_t> lpads;
std::vector<ck_tile::index_t> rpads;
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads,
arg_parser);
ck_tile::conv::ConvParam conv_param{num_dim_sp,
arg_parser.get_int("g"),
arg_parser.get_int("n"),
arg_parser.get_int("k"),
arg_parser.get_int("c"),
filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads};
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
const auto in_g_n_c_wis_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
if(init_method == 0)
{
ck_tile::FillUniformDistribution<WeiDataType>{-1.f, 1.f}(weight);
ck_tile::FillUniformDistribution<OutDataType>{-1.f, 1.f}(output);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
ck_tile::FillMonotonicSeq<OutDataType>{}(output);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(output);
}
else
{
weight.SetZero();
output.SetZero();
}
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
input_dev_buf.SetZero();
weight_dev_buf.ToDevice(weight.data());
output_dev_buf.ToDevice(output.data());
ck_tile::GroupedConvBwdDataHostArgs args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);
std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl;
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_data<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
input_dev_buf.FromDevice(input.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<InDataType> input_host_ref(in_g_n_c_wis_desc);
input_host_ref.SetZero();
ck_tile::reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
input_host_ref,
weight,
output,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_);
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const float max_accumulated_value =
*std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end());
const auto rtol_atol =
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
GemmK, kbatch, max_accumulated_value);
pass = ck_tile::check_err(input,
input_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
throw std::runtime_error("Unsupported gpu verification !!!");
}
return pass;
}

View File

@@ -167,17 +167,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// 4. Run the kernel
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(N, 1), // Input Stride
ck_tile::make_tuple(N, 1), // Output Stride
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(N, 1), // Input Stride
ck_tile::make_tuple(N, 1), // Output Stride
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;

View File

@@ -113,7 +113,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// Run the kernel
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,

View File

@@ -112,17 +112,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// 4. Run the kernel
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0, // Shared memory
op_lengths, // Logical dimensions for the operation (M, N)
input_strides, // Strides for input tensor(s)
output_strides, // Strides for output tensor (N, M)
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0, // Shared memory
op_lengths, // Logical dimensions for the operation (M, N)
input_strides, // Strides for input tensor(s)
output_strides, // Strides for output tensor (N, M)
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;

View File

@@ -99,17 +99,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// 4. Run the kernel
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(N, 1), // Input Stride
ck_tile::make_tuple(N, 1), // Output Stride
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(N, 1), // Input Stride
ck_tile::make_tuple(N, 1), // Output Stride
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;

View File

@@ -6,7 +6,7 @@ This folder contains example for batched Transpose using ck_tile tile-programmin
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# Make the transpose executable
make tile_example_batched_transpose -j
```

View File

@@ -74,8 +74,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
const dim3 grids = kernel::GridSize(a);
const dim3 blocks = kernel::BlockSize();
printf("Pipeline: %d\n", Config::kPipelineId);
printf("Grid: x=%u y=%u z=%u\n", grids.x, grids.y, grids.z);
@@ -96,8 +96,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
printf("Launching Kernel...\n");
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
float ave_time =
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs));
printf("Kernel finished...\n");

View File

@@ -8,9 +8,8 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp)
target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp)
target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp)
target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()

View File

@@ -7,9 +7,10 @@ This folder contains example for Block Scale GEMM using ck_tile tile-programming
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# The aquant pipeline method on the gemm calculation
make tile_example_gemm_aquant_basic -j
make tile_example_gemm_bquant_basic -j
```
This will result in an executable `build/bin/tile_example_gemm_aquant_basic`

View File

@@ -8,11 +8,10 @@
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -21,29 +20,26 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -52,8 +48,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -68,7 +69,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
constexpr bool transposed_warp_gemm = true;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
@@ -82,6 +83,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transposed_warp_gemm,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
@@ -96,7 +98,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -111,8 +112,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
@@ -136,7 +137,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
@@ -187,13 +188,14 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
@@ -201,32 +203,18 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
float,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
float,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
@@ -235,4 +223,4 @@ int run_gemm_example(int argc, char* argv[])
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigComputeV3>(argc, argv); }
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }

View File

@@ -8,11 +8,10 @@
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -21,8 +20,7 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
@@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = 16;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 1;
constexpr ck_tile::index_t N_Warp = 4;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 32;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
@@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits =
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
@@ -82,6 +85,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transposed_warp_gemm,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
@@ -96,7 +100,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
@@ -111,8 +114,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
constexpr dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
@@ -136,7 +139,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
@@ -187,13 +190,14 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
@@ -201,7 +205,7 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
float,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
@@ -210,29 +214,18 @@ int run_gemm_example(int argc, char* argv[])
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
float,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigPreshufle_AQ>(argc, argv); }
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigPreshuffleQuant>(argc, argv);
}

View File

@@ -0,0 +1,228 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmBQuantTraits<kPadM,
kPadN,
kPadK,
GemmConfig::PreshuffleQuant,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
BQDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel =
ck_tile::BQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
;
}
#include "run_gemm_bquant_example.inc"
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for B.");
}
return 0;
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8i4")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8i4")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }

View File

@@ -11,11 +11,9 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE 5
#define CK_TILE_PIPELINE_PREFILL 1
#define CK_TILE_PIPELINE_DECODE 2
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -83,200 +81,37 @@ struct GemmConfigBase
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
struct GemmConfigDecode : public GemmConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
};
template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 1;
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};
template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 2;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -288,18 +123,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
};
template <typename PrecType>
struct GemmConfigPreshufle_AQ : public GemmConfigBase
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
@@ -314,9 +146,11 @@ struct GemmConfigPreshufle_AQ : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
static constexpr bool PreshuffleQuant = true;
};
template <typename ADataType_,
@@ -332,176 +166,6 @@ struct GemmQuantTypeConfig
using CDataType = CDataType_;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using QDataType = float;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
{
using ADataType = ck_tile::pk_int4_t;
using QDataType = float;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using QDataType = ck_tile::fp8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using QDataType = ck_tile::bf8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::fp8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
{
using ADataType = ck_tile::bf8_t;
using QDataType = float;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
@@ -559,55 +223,6 @@ struct DataTypeTraits<ck_tile::int8_t>
static constexpr const char* name = "int8";
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;

View File

@@ -1,3 +1,4 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
@@ -31,7 +32,8 @@ auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
@@ -40,8 +42,7 @@ template <typename ADataType,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
bool Preshuffle = false>
uint32_t QuantGroupSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
@@ -73,7 +74,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
float ave_time = gemm_calc_aquant<ADataType,
float ave_time = gemm_calc_aquant<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
@@ -82,8 +84,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ALayout,
BLayout,
CLayout,
QuantGroupSize,
Preshuffle>(
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
@@ -206,7 +207,7 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(GemmConfig::Preshuffle)
if constexpr(GemmConfig::PreshuffleQuant)
{
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
shuffle_aq(aq_m_aqk, GemmConfig::K_Tile / QuantGroupSize);
@@ -222,7 +223,8 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType,
invoke_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
@@ -231,22 +233,21 @@ int run_gemm_example_with_layouts(int argc,
AQLayout,
BLayout,
CLayout,
QuantGroupSize,
GemmConfig::Preshuffle>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
QuantGroupSize>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;

View File

@@ -0,0 +1,286 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <bit>
#include <random>
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename T>
auto shuffle_bq(const ck_tile::HostTensor<T>& t, int block_bq_k)
{
if(t.get_lengths().size() != 2)
{
throw std::runtime_error("Host tensor is not rank 2 tensor.");
}
int n_ = t.get_lengths()[0];
int bqk_ = t.get_lengths()[1];
if(bqk_ % block_bq_k != 0)
{
throw std::runtime_error("shuffle_aq needs a bqk of multiple times of block_bq_k.");
}
ck_tile::HostTensor<T> t_view({n_, bqk_ / block_bq_k, block_bq_k});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {1, 0, 2});
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename BQDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename BQLayout,
typename DsLayout,
typename CLayout,
uint32_t QuantGroupSize,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& bq_bqk_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t BQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_BQ,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::BQuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.bq_ptr = bq_bqk_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK = BQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_BQ = stride_BQ;
float ave_time = gemm_calc_bquant<GemmConfig,
ADataType,
BDataType,
BQDataType,
AccDataType,
CDataType,
ADataType, // computeDatatype
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K +
sizeof(BQDataType) * BQK * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideBQ =" << stride_BQ
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " A_Type = " << DataTypeTraits<ADataType>::name
<< " B_Type = " << DataTypeTraits<BDataType>::name
<< " BQ_Type = " << DataTypeTraits<BQDataType>::name
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename ALayout,
typename BLayout,
typename BQLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
const BQLayout bq_layout = BQLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using BQDataType = typename TypeConfig::QDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be aligned with QuantGroupSize");
}
ck_tile::index_t BQK = K / QuantGroupSize;
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_BQ = arg_parser.get_int("stride_q");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<BQDataType> bq_bqk_n(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
{
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(bq_bqk_n);
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
}
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return 0;
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(bq_bqk_n);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
bq_bqk_n.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<GemmConfig,
ADataType,
BDataType,
BQDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
BQLayout,
ck_tile::tuple<>,
CLayout,
QuantGroupSize>(a_m_k_dev_buf,
b_k_n_dev_buf,
bq_bqk_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
BQK,
stride_A,
stride_B,
stride_BQ,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_quant<ADataType,
BQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
}
return pass;
}

View File

@@ -12,7 +12,7 @@ This experimental kernel is intended for novice CK developers. It introduces the
mkdir build && cd build
# you can replace <arch> with the appropriate architecture
# (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
../script/cmake-ck-dev.sh ../ <arch>
# Make the copy kernel executable
make tile_example_copy -j
```

View File

@@ -77,10 +77,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
// we intentionally do not use pipeline for this example and let the kernel be composite of
// Problem and Policy
constexpr ck_tile::index_t kBlockSize = Shape::BlockSize;
auto blockSize = Kernel::BlockSize();
// Print configuration information
std::cout << "block size (number of threads per block) " << kBlockSize << std::endl;
std::cout << "block size (number of threads per block) " << blockSize << std::endl;
std::cout << "wave size (number of threads per wave) " << ck_tile::get_warp_size() << std::endl;
std::cout << "block waves (number of waves per block) " << BlockWaves::at(ck_tile::number<0>{})
<< " " << BlockWaves::at(ck_tile::number<1>{}) << std::endl;
@@ -99,16 +99,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< ")" << std::endl;
// Launch kernel
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, warmup, repeat, 1},
ck_tile::make_kernel<kBlockSize, 1>(Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
m,
n));
float ave_time =
launch_kernel(ck_tile::stream_config{nullptr, true, warmup, repeat, 1},
ck_tile::make_kernel<1>(Kernel{},
kGridSize,
blockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
m,
n));
// Calculate and print performance metrics
std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m * n;

View File

@@ -27,8 +27,9 @@ struct TileCopyShape
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
// Wave tile dimensions
static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{});
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t Wave_Tile_N = WaveTile::at(number<1>{});
static constexpr index_t Wave_Tile_M = ThreadTile_M * ThreadTile_N * WaveSize / Wave_Tile_N;
// Block tile dimensions
static constexpr index_t Block_Tile_M = BlockTile::at(number<0>{});
@@ -45,7 +46,6 @@ struct TileCopyShape
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
// Hardware configuration
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t BlockSize = Waves_Per_Block_M * Waves_Per_Block_N * WaveSize;
// Configuration validation
@@ -60,8 +60,10 @@ struct TileCopyShape
"Invalid wave configuration for N dimension");
// Ensure wave tile dimensions align with wave size
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize,
"(Wave_Tile_M/ThreadTile_M) * (Wave_Tile_N/ThreadTile_N) != WaveSize");
#endif
};
/**
@@ -200,6 +202,19 @@ struct ElementWiseTileCopyKernel
using XDataType = typename Problem::XDataType;
using Policy = ck_tile::remove_cvref_t<Policy_>;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
CK_TILE_HOST static auto BlockSize()
{
if(ck_tile::is_wave32())
{
return kBlockSize / 2;
}
else
{
return kBlockSize;
}
}
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
{
using S = typename Problem::BlockShape;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -52,10 +52,27 @@ inline std::string get_device_name()
}
}
inline bool is_gfx12_supported()
{
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
}
inline bool is_gfx11_supported()
{
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" ||
ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" ||
ck::get_device_name() == "gfx1152";
}
inline bool is_xdl_supported()
{
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
|| is_gfx12_supported() || is_gfx11_supported()
#endif
;
}
inline bool is_lds_direct_load_supported()
@@ -67,7 +84,8 @@ inline bool is_lds_direct_load_supported()
inline bool is_bf16_atomic_supported()
{
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
is_gfx12_supported();
}
inline bool is_gfx101_supported()
@@ -83,18 +101,5 @@ inline bool is_gfx103_supported()
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
}
inline bool is_gfx11_supported()
{
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" ||
ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" ||
ck::get_device_name() == "gfx1152";
}
inline bool is_gfx12_supported()
{
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
}
} // namespace ck
#endif

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdexcept>
#include <string>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
namespace ck {
namespace utils {
template <typename Layout>
inline void
validate_gemm_stride(int M, int N, int stride, const std::string& stride_name = "Stride")
{
if(ck::is_same_v<Layout, ck::tensor_layout::gemm::ColumnMajor>)
{
if(stride < M)
{
throw std::runtime_error(
"Error: For ColumnMajor layout, " + stride_name + " (" + std::to_string(stride) +
") must be greater than or equal to dim (" + std::to_string(M) + ")");
}
}
else // RowMajor
{
if(stride < N)
{
throw std::runtime_error(
"Error: For RowMajor layout, " + stride_name + " (" + std::to_string(stride) +
") must be greater than or equal to dim (" + std::to_string(N) + ")");
}
}
}
// Convenience functions for common GEMM patterns
template <typename ALayout, typename BLayout, typename CLayout>
inline void validate_gemm_strides_abc(int M, int N, int K, int StrideA, int StrideB, int StrideC)
{
validate_gemm_stride<ALayout>(M, K, StrideA, "StrideA");
validate_gemm_stride<BLayout>(K, N, StrideB, "StrideB");
validate_gemm_stride<CLayout>(M, N, StrideC, "StrideC");
}
} // namespace utils
} // namespace ck

View File

@@ -41,7 +41,9 @@ struct BlockwiseGemmXdlops_pipeline_base
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
static constexpr index_t WaveSize = 64;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
@@ -74,9 +76,6 @@ struct BlockwiseGemmXdlops_pipeline_base
return 1;
}();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
using HotLoopInstList =
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
@@ -219,6 +218,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
@@ -227,6 +227,7 @@ struct BlockwiseGemmXdlops_pipeline_base
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
"wrong!");
#endif
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -625,13 +626,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::WaveSize;
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -141,9 +141,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::WaveSize;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
@@ -626,13 +627,14 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::WaveSize;
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t WgpPerCU =
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -159,6 +159,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__device__ static constexpr auto HotLoopScheduler()
{
#if !defined(__gfx11__) && !defined(__gfx12__)
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
@@ -260,6 +261,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
#endif
}
template <bool HasMainLoop,

View File

@@ -176,8 +176,36 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
BElementwiseOperation,
CElementwiseOperation>
{
template <bool isWave64>
static constexpr auto GetNXdlPerWave()
{
constexpr index_t Waves = isWave64 ? BlockSize / 64 : BlockSize / 32;
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL);
static_assert(MWaves > 0);
constexpr index_t NWaves = Waves / MWaves;
if constexpr(NWaves == 0)
{
return 0;
}
else
{
if constexpr(NPerBlock % (NPerXDL * NWaves) == 0)
{
return NPerBlock / (NWaves * NPerXDL);
}
else
{
return 0;
}
}
}
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
@@ -199,7 +227,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
NXdlPerWave_,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
@@ -226,8 +254,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
ComputeTypeB,
PermuteA,
PermuteB>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
using Argument = typename GridwiseGemm::Argument;
using Argument = typename GridwiseGemm64::Argument;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
@@ -254,12 +284,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
///
struct Invoker : public BaseInvoker
{
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
template <typename GridwiseGemm>
float RunImp(const typename GridwiseGemm::Argument& arg,
const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
@@ -285,7 +312,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
auto arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
@@ -297,7 +324,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / BPackedSize;
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
@@ -733,6 +760,31 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return ave_time;
}
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
{
return RunImp<GridwiseGemm64>(arg, stream_config);
}
}
else
{
if constexpr(NXdlPerWave32 > 0)
{
return RunImp<GridwiseGemm32>(
reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg),
stream_config);
}
}
return 0;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
@@ -754,9 +806,39 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
if(arg.KBatch > 1)
{
return false;
if(is_gfx11_supported())
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t>)
{
return false;
}
if(sizeof(CDataType) == 1)
{
return false;
}
}
if(is_gfx11_supported() || is_gfx12_supported())
{
if(MPerXDL != 16 || NPerXDL != 16)
{
return false;
}
}
if(is_gfx11_supported())
{
if constexpr(std::is_same_v<ADataType, ck::f8_t> ||
std::is_same_v<ADataType, ck::bf8_t>)
{
return false;
}
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
@@ -767,7 +849,29 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return false;
}
return GridwiseGemm::CheckValidity(arg);
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
{
return GridwiseGemm64::CheckValidity(arg);
}
else
{
return false;
}
}
else
{
if constexpr(NXdlPerWave32 > 0)
{
return GridwiseGemm32::CheckValidity(
reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
}
else
{
return false;
}
}
}
// polymorphic
@@ -849,6 +953,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
index_t PrefetchStages = 0;
index_t AMmaKStride = 0;
if(get_warp_size() == 64)
{
if constexpr(NXdlPerWave64 > 0)
{
PrefetchStages = GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
AMmaKStride = GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
}
}
else
{
if constexpr(NXdlPerWave32 > 0)
{
PrefetchStages = GridwiseGemm32::BlockwiseGemmPipe::PrefetchStages;
AMmaKStride = GridwiseGemm32::BlockwiseGemmPipe::AMmaKStride;
}
}
// clang-format off
str << "DeviceGemmXdlUniversal"
<< "<"
@@ -872,9 +995,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< PrefetchStages << ", "
<< "Kpack: "
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
<< AMmaKStride;
// clang-format on
return str.str();

View File

@@ -641,7 +641,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / KPerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{

View File

@@ -506,7 +506,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
// Ensure that k_batch_ does not exceed the maximum value
// for the GEMM pipeline.
const auto k_batch_max = static_cast<index_t>((gemmK - 1) / K0PerBlock);
k_batch_ = std::min(k_batch_, k_batch_max);
k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1);
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{

View File

@@ -35,19 +35,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__)
enum struct Arch : bool
{
#if defined(__gfx950__)
is_gfx950_build = true,
#else
is_gfx950_build = false,
#endif
};
// skip building the instances with K1>=32 on pre-gfx950
if constexpr(((GridwiseGemm::AK1Number >= 32 || GridwiseGemm::BK1Number >= 32) &&
static_cast<bool>(Arch::is_gfx950_build)) ||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32))
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -77,22 +66,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if defined(__gfx9__)
enum struct Arch : bool
#if defined(__gfx9__) || defined(__gfx12__) || defined(__gfx11__)
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
{
#if defined(__gfx950__)
is_gfx950_build = true,
#else
is_gfx950_build = false,
#endif
};
// skip building the instances with K1>=32 on pre-gfx950
if constexpr(((GridwiseGemm::AK1Number >= 32 || GridwiseGemm::BK1Number >= 32) &&
static_cast<bool>(Arch::is_gfx950_build)) ||
(GridwiseGemm::AK1Number < 32 && GridwiseGemm::BK1Number < 32))
{
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -694,12 +672,23 @@ struct GridwiseGemm_xdl_cshuffle_v3
__host__ void Print() const
{
std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
<< ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
// clang-format off
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
// clang-format off
}
index_t M;
@@ -829,6 +818,10 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
// A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
@@ -886,7 +879,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / MPerXdl;
constexpr auto KThreadRead = WaveSize / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
@@ -967,6 +960,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWaves = (NXdlPerWave * NPerXdl == 0) ? 0 : NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = (MWaves * NWaves == 0) ? 64 : BlockSize / (MWaves * NWaves);
// B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
@@ -1020,7 +1016,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto KThreadRead = WaveSize / NPerXdl;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
@@ -1167,12 +1163,99 @@ struct GridwiseGemm_xdl_cshuffle_v3
c_block_size * sizeof(CShuffleDataType));
}
template <InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__device__ static bool constexpr IsValidCompilationParameter()
{
enum struct Arch : bool
{
#if defined(__gfx950__)
is_gfx950_build = true,
#else
is_gfx950_build = false,
#endif
};
// skip building the instances with K1>=32 && PackedSize != 2 on pre-gfx950
if constexpr(static_cast<bool>(Arch::is_gfx950_build) ||
(AK1Number < 32 && BK1Number < 32) ||
(AK1Number >= 32 && APackedSize == 2) ||
(BK1Number >= 32 && BPackedSize == 2))
{
}
else
{
return false;
}
// Check tile size
#if defined(__gfx11__) || defined(__gfx12__)
if constexpr(MPerXdl != 16 || NPerXdl != 16)
{
return false;
}
#endif
// Check atomic caps
#if defined(__gfx11__)
constexpr bool SupportMemOp = CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set;
#else
constexpr bool SupportMemOp = sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation ==
InMemoryDataOperationEnum::Set);
#endif
if constexpr(SupportMemOp == false)
{
return false;
}
// Check tile size
if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
{
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
if constexpr(MWaves > 0 && NWaves > 0)
{
constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
if constexpr(WaveSize == get_warp_size())
{
return true;
}
else
{
return false;
}
}
else
{
return false;
}
}
else
{
return false;
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__ static constexpr bool CheckValidity(const Argument& karg)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
if constexpr((MPerXdl * MXdlPerWave) == 0 || (NXdlPerWave * NPerXdl) == 0)
{
return false;
}
else
{
if constexpr((MPerBlock % (MPerXdl * MXdlPerWave) != 0) ||
(NPerBlock % (NXdlPerWave * NPerXdl) != 0))
{
return false;
}
else
{
if(BlockwiseGemmPipe::WaveSize != get_warp_size())
{
return false;
}
}
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||

View File

@@ -6,6 +6,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_xdlops.hpp"
#include "ck/utility/amd_wmma.hpp"
namespace ck {
/**
@@ -76,7 +77,21 @@ enum struct MfmaInstr
mfma_f32_32x32x64f8f6f4,
mfma_f32_16x16x128f8f6f4,
mfma_scale_f32_32x32x64f8f6f4,
mfma_scale_f32_16x16x128f8f6f4
mfma_scale_f32_16x16x128f8f6f4,
// gfx11
wmma_f32_16x16x16_f16,
wmma_f32_16x16x16_bf16,
wmma_i32_16x16x16_iu8,
wmma_unsupport_16x16_gfx11,
// gfx12
wmma_f32_16x16x16_f16_gfx12,
wmma_f32_16x16x16_bf16_gfx12,
wmma_i32_16x16x16_iu8_gfx12,
wmma_f32_16x16x16_f8f8_gfx12,
wmma_f32_16x16x16_f8bf8_gfx12,
wmma_f32_16x16x16_bf8f8_gfx12,
wmma_f32_16x16x16_bf8bf8_gfx12,
wmma_unsupport_16x16_gfx12,
};
template <MfmaInstr instr>
@@ -932,6 +947,175 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
}
};
// gfx11
struct mfma_type_gfx11_base
{
static constexpr index_t group_size = 8;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 8;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 32;
static constexpr index_t num_input_blks = 1;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 16;
static constexpr bool is_k_reduction = true;
};
template <>
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f16> : public mfma_type_gfx11_base
{
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_f32_16x16x16_f16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf16> : public mfma_type_gfx11_base
{
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_f32_16x16x16_bf16_w32<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_i32_16x16x16_iu8> : public mfma_type_gfx11_base
{
template <index_t MPerWmma,
index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC,
bool neg_a = true,
bool neg_b = true,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_i32_16x16x16_iu8_w32<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx11> : public mfma_type_gfx11_base
{
static constexpr index_t k_per_blk = 2;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA&, const FloatB&, FloatC&) const
{
// empty for all unsupported types.
}
};
// gfx12
struct mfma_type_gfx12_base
{
static constexpr index_t group_size = 8;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 8;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 32;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
};
template <>
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f16_gfx12> : public mfma_type_gfx12_base
{
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf16_gfx12> : public mfma_type_gfx12_base
{
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_i32_16x16x16_iu8_gfx12> : public mfma_type_gfx12_base
{
template <index_t MPerWmma,
index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC,
bool neg_a = true,
bool neg_b = true,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12> : public mfma_type_gfx12_base
{
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12> : public mfma_type_gfx12_base
{
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12> : public mfma_type_gfx12_base
{
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12> : public mfma_type_gfx12_base
{
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::wmma_unsupport_16x16_gfx12> : public mfma_type_gfx12_base
{
static constexpr index_t k_per_blk = 2;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA&, const FloatB&, FloatC&) const
{
// empty for all unsupported types.
}
};
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
@@ -951,7 +1135,13 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<double, 16, 16>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#else
return MfmaInstr::mfma_f64_16x16x4f64;
#endif
}
template <>
@@ -993,7 +1183,13 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<float, 16, 16>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#else
return MfmaInstr::mfma_f32_16x16x4xf32;
#endif
}
template <>
@@ -1026,7 +1222,11 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
{
#if defined(__gfx950__)
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_f16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_f32_16x16x16_f16;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32f16;
#else
return MfmaInstr::mfma_f32_16x16x16f16;
@@ -1036,7 +1236,13 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_f16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_f32_16x16x16_f16;
#else
return MfmaInstr::mfma_f32_16x16x16f16;
#endif
}
template <>
@@ -1082,7 +1288,11 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
{
#if defined(__gfx950__)
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_f32_16x16x16_bf16;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32bf16;
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
@@ -1094,7 +1304,11 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_bf16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_f32_16x16x16_bf16;
#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
#else
return MfmaInstr::mfma_f32_16x16x8bf16;
@@ -1126,7 +1340,11 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<int8_t, 16, 16, int8_t, false>()
{
#if defined(__gfx950__)
#if defined(__gfx12__)
return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_i32_16x16x16_iu8;
#elif defined(__gfx950__)
return MfmaInstr::mfma_i32_16x16x64i8;
#elif defined(__gfx942__)
return MfmaInstr::mfma_i32_16x16x32i8;
@@ -1138,7 +1356,11 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<int8_t, 16, 16, int8_t, true>()
{
#if defined(__gfx942__) || defined(__gfx950__)
#if defined(__gfx12__)
return MfmaInstr::wmma_i32_16x16x16_iu8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_i32_16x16x16_iu8;
#elif defined(__gfx942__) || defined(__gfx950__)
return MfmaInstr::mfma_i32_16x16x32i8;
#else
return MfmaInstr::mfma_i32_16x16x16i8;
@@ -1186,13 +1408,23 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#else
return MfmaInstr::mfma_f32_16x16x32f8f8;
#endif
}
template <>
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
{
#if defined(__gfx950__)
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
#else
return MfmaInstr::mfma_f32_16x16x32f8f8;
@@ -1263,13 +1495,23 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#else
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
#endif
}
template <>
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
{
#if defined(__gfx950__)
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
#else
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
@@ -1295,13 +1537,23 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#else
return MfmaInstr::mfma_f32_16x16x32f8bf8;
#endif
}
template <>
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
{
#if defined(__gfx950__)
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
#else
return MfmaInstr::mfma_f32_16x16x32f8bf8;
@@ -1327,13 +1579,23 @@ struct MfmaSelector
template <>
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#else
return MfmaInstr::mfma_f32_16x16x32bf8f8;
#endif
}
template <>
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
{
#if defined(__gfx950__)
#if defined(__gfx12__)
return MfmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
#else
return MfmaInstr::mfma_f32_16x16x32bf8f8;
@@ -1355,10 +1617,18 @@ struct MfmaSelector
static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
"n_per_blk != num_threads_per_blk");
#if defined(__gfx11__)
if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
{
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks * 2 ==
selected_mfma.m_per_blk,
"m_per_blk != num_input_blks * num_regs_per_blk");
}
#else
static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
selected_mfma.m_per_blk,
"m_per_blk != num_input_blks * num_regs_per_blk");
#endif
static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
selected_mfma.num_output_blks == 1,
@@ -1424,8 +1694,9 @@ struct XdlopsGemm
static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
#if defined(__HIP_DEVICE_COMPILE__)
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
#endif
}
// XDL output supporting C = A * B
@@ -1434,10 +1705,11 @@ struct XdlopsGemm
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
@@ -1446,7 +1718,7 @@ struct XdlopsGemm
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<num_blks>{},
Number<mfma_instr.group_size>{})),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{},
@@ -1469,12 +1741,13 @@ struct XdlopsGemm
__host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
@@ -1485,7 +1758,7 @@ struct XdlopsGemm
make_pass_through_transform(M2),
make_pass_through_transform(N2),
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<num_blks>{},
Number<mfma_instr.group_size>{})),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{},
@@ -1512,10 +1785,11 @@ struct XdlopsGemm
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
@@ -1525,7 +1799,7 @@ struct XdlopsGemm
make_pass_through_transform(N1),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{}),
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<num_blks>{},
Number<mfma_instr.group_size>{}))),
make_tuple(Sequence<0>{},
Sequence<1>{},
@@ -1545,11 +1819,12 @@ struct XdlopsGemm
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
{
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
return transform_tensor_descriptor(
c_desc_g_m0_n0_m1_n1_m2_n2,
@@ -1558,9 +1833,8 @@ struct XdlopsGemm
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
mfma_instr.num_input_blks,
mfma_instr.group_size)),
make_unmerge_transform(make_tuple(
mfma_instr.num_groups_per_blk, num_blks, mfma_instr.group_size)),
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
make_tuple(Sequence<0>{},
Sequence<1>{},
@@ -1642,8 +1916,32 @@ struct XdlopsGemm
__device__ static auto GetBlkIdx()
{
const auto laneId = GetLaneId();
const auto laneId = GetLaneId();
constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(
make_merge_transform(make_tuple(1, num_blks, mfma_instr.num_threads_per_blk))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto blk_idx =
threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
const auto blk_id = blk_idx[I1];
const auto blk_td = blk_idx[I2];
return make_tuple(blk_id, blk_td);
}
template <bool SwizzleA>
__device__ static auto GetGfx11InputBlkIdx()
{
auto laneId = GetLaneId() % mfma_instr.num_threads_per_blk;
if constexpr(SwizzleA)
{
laneId = ((laneId & 1) << 3) | (laneId >> 1);
}
constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(
make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
@@ -1661,8 +1959,12 @@ struct XdlopsGemm
__host__ __device__ static auto CalculateAThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto laneId = GetLaneId();
#if defined(__gfx11__)
const auto blk_idx = GetGfx11InputBlkIdx<true>();
#else
const auto blk_idx = GetBlkIdx();
#endif
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];
@@ -1679,8 +1981,12 @@ struct XdlopsGemm
__host__ __device__ static auto CalculateBThreadOriginDataIndex()
{
const auto laneId = GetLaneId();
const auto laneId = GetLaneId();
#if defined(__gfx11__)
const auto blk_idx = GetGfx11InputBlkIdx<false>();
#else
const auto blk_idx = GetBlkIdx();
#endif
const auto blk_id = blk_idx[I0];
const auto blk_td = blk_idx[I1];

View File

@@ -75,9 +75,9 @@ template <index_t BlockSize,
bool IsF4F6 = false>
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
{
static constexpr index_t WaveSize = 64;
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t WaveSize = BlockSize / WaveNumM / WaveNumN;
static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -7,6 +7,38 @@
namespace ck {
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
__device__ constexpr index_t get_warp_size()
{
#if defined(__HIP_DEVICE_COMPILE__)
#if defined(__GFX9__)
return 64;
#else
return 32;
#endif
#else
return 64;
#endif
}
inline __host__ index_t get_warp_size()
{
#if !(defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC))
int device = 0;
int result = 0;
auto status = hipGetDevice(&device);
if(status == hipSuccess)
{
status = hipDeviceGetAttribute(&result, hipDeviceAttributeWarpSize, device);
if(status == hipSuccess)
{
return result;
}
}
#endif
return 64;
}
#else
__host__ __device__ constexpr index_t get_warp_size()
{
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
@@ -15,6 +47,7 @@ __host__ __device__ constexpr index_t get_warp_size()
return 32;
#endif
}
#endif
__device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

Some files were not shown because too many files have changed in this diff Show More