mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Merge remote-tracking branch 'origin/develop' into samremes/double_buffer_fp8_ab_scale
This commit is contained in:
112
.github/scripts/therock_configure_ci.py
vendored
Normal file
112
.github/scripts/therock_configure_ci.py
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Iterable, Optional, Mapping
|
||||
|
||||
def gha_set_output(vars: Mapping[str, str | Path]):
|
||||
"""Sets values in a step's output parameters.
|
||||
|
||||
This appends to the file located at the $GITHUB_OUTPUT environment variable.
|
||||
|
||||
See
|
||||
* https://docs.github.com/en/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter
|
||||
* https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs
|
||||
"""
|
||||
print(f"Setting github output:\n{vars}")
|
||||
|
||||
step_output_file = os.getenv("GITHUB_OUTPUT")
|
||||
if not step_output_file:
|
||||
print(" Warning: GITHUB_OUTPUT env var not set, can't set github outputs")
|
||||
return
|
||||
|
||||
with open(step_output_file, "a") as f:
|
||||
f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items())
|
||||
|
||||
def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
|
||||
"""Returns the paths of modified files relative to the base reference."""
|
||||
try:
|
||||
return subprocess.run(
|
||||
["git", "diff", "--name-only", base_ref],
|
||||
stdout=subprocess.PIPE,
|
||||
check=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
).stdout.splitlines()
|
||||
except TimeoutError:
|
||||
print(
|
||||
"Computing modified files timed out. Not using PR diff to determine"
|
||||
" jobs to run.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return None
|
||||
|
||||
# Paths matching any of these patterns are considered to have no influence over
|
||||
# build or test workflows so any related jobs can be skipped if all paths
|
||||
# modified by a commit/PR match a pattern in this list.
|
||||
SKIPPABLE_PATH_PATTERNS = [
|
||||
"docs/*",
|
||||
"*.gitignore",
|
||||
"*.md",
|
||||
"*.pre-commit-config.*",
|
||||
"*LICENSE",
|
||||
'Jenkinsfile',
|
||||
'.github/ISSUE_TEMPLATE/*',
|
||||
'.github/CODEOWNERS',
|
||||
'.github/*.md',
|
||||
'.github/dependabot.yml',
|
||||
]
|
||||
|
||||
def is_path_skippable(path: str) -> bool:
|
||||
"""Determines if a given relative path to a file matches any skippable patterns."""
|
||||
return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS)
|
||||
|
||||
def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool:
|
||||
"""Returns true if at least one path is not in the skippable set."""
|
||||
if paths is None:
|
||||
return False
|
||||
return any(not is_path_skippable(p) for p in paths)
|
||||
|
||||
def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
|
||||
"""Returns true if CI workflows should run given a list of modified paths."""
|
||||
|
||||
if paths is None:
|
||||
print("No files were modified, skipping TheRock CI jobs")
|
||||
return False
|
||||
|
||||
paths_set = set(paths)
|
||||
github_workflows_paths = set(
|
||||
[p for p in paths if p.startswith(".github/workflows")]
|
||||
)
|
||||
other_paths = paths_set - github_workflows_paths
|
||||
|
||||
contains_other_non_skippable_files = check_for_non_skippable_path(other_paths)
|
||||
|
||||
print("should_ci_run_given_modified_paths findings:")
|
||||
print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}")
|
||||
|
||||
if contains_other_non_skippable_files:
|
||||
print("Enabling TheRock CI jobs since a non-skippable path was modified")
|
||||
return True
|
||||
else:
|
||||
print(
|
||||
"Only unrelated and/or skippable paths were modified, skipping TheRock CI jobs"
|
||||
)
|
||||
return False
|
||||
|
||||
def main(args):
|
||||
base_ref = args.get("base_ref")
|
||||
modified_paths = get_modified_paths(base_ref)
|
||||
print("modified_paths (max 200):", modified_paths[:200])
|
||||
enable_jobs = should_ci_run_given_modified_paths(modified_paths)
|
||||
output = {
|
||||
'enable_therock_ci': json.dumps(enable_jobs)
|
||||
}
|
||||
gha_set_output(output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = {}
|
||||
args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1")
|
||||
main(args)
|
||||
8
.github/workflows/therock-ci-linux.yml
vendored
8
.github/workflows/therock-ci-linux.yml
vendored
@@ -21,9 +21,11 @@ jobs:
|
||||
id-token: write
|
||||
container:
|
||||
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292
|
||||
options: -v /runner/config:/home/awsconfig/
|
||||
env:
|
||||
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
|
||||
TEATIME_FORCE_INTERACTIVE: 0
|
||||
AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini
|
||||
steps:
|
||||
- name: Checkout composable_kernel repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
@@ -83,9 +85,9 @@ jobs:
|
||||
echo "----------"
|
||||
du -h -d 1 TheRock/build/artifacts
|
||||
|
||||
- name: Configure AWS Credentials
|
||||
if: always()
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
- name: Configure AWS Credentials for non-forked repos
|
||||
if: ${{ always() && !github.event.pull_request.head.repo.fork }}
|
||||
uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1
|
||||
with:
|
||||
aws-region: us-east-2
|
||||
role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external
|
||||
|
||||
31
.github/workflows/therock-ci.yml
vendored
31
.github/workflows/therock-ci.yml
vendored
@@ -5,6 +5,15 @@ on:
|
||||
branches:
|
||||
- develop
|
||||
workflow_dispatch:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- synchronize
|
||||
branches:
|
||||
- mainline
|
||||
- release/*
|
||||
- release-staging/*
|
||||
- develop
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -18,8 +27,29 @@ concurrency:
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
setup:
|
||||
runs-on: ubuntu-24.04
|
||||
env:
|
||||
# The commit being checked out is the merge commit for a PR. Its first
|
||||
# parent will be the tip of the base branch.
|
||||
BASE_REF: HEAD^
|
||||
outputs:
|
||||
enable_therock_ci: ${{ steps.configure.outputs.enable_therock_ci }}
|
||||
steps:
|
||||
- name: "Checking out repository"
|
||||
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
|
||||
with:
|
||||
# We need the parent commit to do a diff
|
||||
fetch-depth: 2
|
||||
|
||||
- name: "Configuring CI options"
|
||||
id: configure
|
||||
run: python .github/scripts/therock_configure_ci.py
|
||||
|
||||
therock-ci-linux:
|
||||
name: TheRock CI Linux
|
||||
needs: setup
|
||||
if: ${{ needs.setup.outputs.enable_therock_ci == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
@@ -34,6 +64,7 @@ jobs:
|
||||
name: TheRock CI Summary
|
||||
if: always()
|
||||
needs:
|
||||
- setup
|
||||
- therock-ci-linux
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
|
||||
1
.github/workflows/therock-test-packages.yml
vendored
1
.github/workflows/therock-test-packages.yml
vendored
@@ -68,6 +68,7 @@ jobs:
|
||||
VENV_DIR: ${{ env.VENV_DIR }}
|
||||
FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }}
|
||||
PLATFORM: ${{ inputs.platform }}
|
||||
IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }}
|
||||
|
||||
- name: Test
|
||||
timeout-minutes: ${{ matrix.components.timeout_minutes }}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
|
||||
|
||||
## Composable Kernel 1.1.0 for ROCm 7.0.0
|
||||
## Composable Kernel 1.2.0 for ROCm 7.0.0
|
||||
|
||||
### Added
|
||||
|
||||
@@ -27,6 +27,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added int8 support for CK_TILE GEMM.
|
||||
* Added support for elementwise kernel.
|
||||
* Added benchmarking support for tile engine GEMM Multi D.
|
||||
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
|
||||
|
||||
### Optimized
|
||||
|
||||
@@ -48,6 +49,7 @@ None
|
||||
* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced.
|
||||
* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced.
|
||||
* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced.
|
||||
* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594)
|
||||
|
||||
### Known issues
|
||||
|
||||
|
||||
@@ -16,12 +16,21 @@ else()
|
||||
"Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel.")
|
||||
endif()
|
||||
|
||||
# Allow user to specify the C++ standard.
|
||||
# We must support C++17 builds until downstream users are migrated to C++20, but we default to C++20.
|
||||
set(CK_CXX_STANDARD "20" CACHE STRING "C++ standard to use (e.g. 17 or 20)")
|
||||
set(valid_cxx_standards 17 20)
|
||||
set_property(CACHE CK_CXX_STANDARD PROPERTY STRINGS ${valid_cxx_standards})
|
||||
if(NOT CK_CXX_STANDARD IN_LIST valid_cxx_standards)
|
||||
message(FATAL_ERROR "CK_CXX_STANDARD must be one of ${valid_cxx_standards}")
|
||||
endif()
|
||||
|
||||
# Default installation path
|
||||
if(NOT WIN32)
|
||||
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
|
||||
endif()
|
||||
|
||||
set(version 1.1.0)
|
||||
set(version 1.2.0)
|
||||
# Check support for CUDA/HIP in Cmake
|
||||
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
|
||||
include(CTest)
|
||||
@@ -221,11 +230,20 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9
|
||||
add_definitions(-DCK_USE_GFX94)
|
||||
set(CK_USE_GFX94 "ON")
|
||||
endif()
|
||||
|
||||
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
|
||||
set(CK_TILE_USE_WMMA 0)
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message(STATUS "Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
set(CK_TILE_USE_WMMA 1)
|
||||
endif()
|
||||
|
||||
# define the macro with the current value (0 or 1)
|
||||
add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_WMMA_FP8)
|
||||
@@ -324,32 +342,19 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
|
||||
endif()
|
||||
|
||||
if(USE_OPT_GFX11)
|
||||
add_compile_options(-mcumode)
|
||||
add_compile_options(-mno-wavefrontsize64)
|
||||
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
|
||||
message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}")
|
||||
endif()
|
||||
|
||||
if(ENABLE_ASM_DUMP)
|
||||
add_compile_options(--save-temps)
|
||||
add_compile_options(-Wno-gnu-line-marker)
|
||||
message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}")
|
||||
endif()
|
||||
|
||||
if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12"))
|
||||
add_compile_options(-mno-wavefrontsize64)
|
||||
add_compile_definitions(CK_TILE_WAVE32_ENABLED)
|
||||
message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}")
|
||||
endif()
|
||||
|
||||
## Threads
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
find_package(Threads REQUIRED)
|
||||
link_libraries(Threads::Threads)
|
||||
|
||||
## C++
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD ${CK_CXX_STANDARD})
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
|
||||
|
||||
23
Dockerfile.pytorch
Normal file
23
Dockerfile.pytorch
Normal file
@@ -0,0 +1,23 @@
|
||||
ARG BASE_DOCKER="rocm/pytorch-nightly:latest"
|
||||
FROM $BASE_DOCKER
|
||||
ARG CK_PYTORCH_BRANCH="develop"
|
||||
RUN groupadd -g 109 render && \
|
||||
usermod -u 1001 jenkins && \
|
||||
groupmod -g 1001 jenkins && \
|
||||
cd /tmp/pytorch && \
|
||||
rm -rf build && \
|
||||
cd /tmp/pytorch/third_party && \
|
||||
rm -rf composable_kernel && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
cd /tmp/pytorch/third_party/aiter/3rdparty && \
|
||||
rm -rf composable_kernel && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
cd /tmp/pytorch/third_party/fbgemm/external && \
|
||||
rm -rf composable_kernel && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
cd /tmp/pytorch/third_party/flash-attention/csrc && \
|
||||
rm -rf composable_kernel && \
|
||||
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
|
||||
chown -R jenkins:jenkins /tmp/pytorch && \
|
||||
chmod -R a+rwx /tmp/pytorch && \
|
||||
sudo usermod -aG irc jenkins
|
||||
237
Jenkinsfile
vendored
237
Jenkinsfile
vendored
@@ -192,12 +192,16 @@ def buildDocker(install_prefix){
|
||||
image_name = "rocm/composable_kernel:ck_aiter"
|
||||
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . "
|
||||
}
|
||||
else{
|
||||
else if(params.RUN_PYTORCH_TESTS){
|
||||
image_name = "rocm/composable_kernel:ck_pytorch"
|
||||
dockerArgs = dockerArgs + " --no-cache -f Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . "
|
||||
}
|
||||
else{
|
||||
dockerArgs = dockerArgs + " -f Dockerfile . "
|
||||
}
|
||||
echo "Build Args: ${dockerArgs}"
|
||||
try{
|
||||
if(params.BUILD_DOCKER || params.RUN_AITER_TESTS){
|
||||
if(params.BUILD_DOCKER || params.RUN_AITER_TESTS || params.RUN_PYTORCH_TESTS){
|
||||
//force building the new docker if that parameter is true
|
||||
echo "Building image: ${image_name}"
|
||||
retimage = docker.build("${image_name}", dockerArgs)
|
||||
@@ -400,8 +404,9 @@ def cmake_build(Map conf=[:]){
|
||||
echo "Build packages"
|
||||
sh 'ninja -j64 package'
|
||||
archiveArtifacts artifacts: 'composablekernel-dev*.deb'
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb'
|
||||
stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages"
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.2.0_amd64.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.2.0_amd64.deb'
|
||||
stash includes: "composablekernel-**.deb", name: "packages"
|
||||
}
|
||||
}
|
||||
else{
|
||||
@@ -571,50 +576,66 @@ def Build_CK(Map conf=[:]){
|
||||
python3 -m pytest python/test/test_gen_instances.py
|
||||
"""
|
||||
}
|
||||
dir("build"){
|
||||
if (params.RUN_FULL_QA && arch == 2 ){
|
||||
// build deb packages
|
||||
echo "Build packages"
|
||||
sh 'ninja package'
|
||||
archiveArtifacts artifacts: 'composablekernel*.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb'
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb'
|
||||
sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb'
|
||||
sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb'
|
||||
stash includes: "composablekernel-**.deb", name: "packages"
|
||||
}
|
||||
}
|
||||
// run performance tests, stash the logs, results will be processed on the master node
|
||||
dir("script"){
|
||||
if (params.RUN_PERFORMANCE_TESTS){
|
||||
if (params.RUN_FULL_QA && arch == 1){
|
||||
// run full tests on gfx90a
|
||||
echo "Run full performance tests"
|
||||
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
|
||||
archiveArtifacts "perf_gemm.log"
|
||||
archiveArtifacts "perf_resnet50_N256.log"
|
||||
archiveArtifacts "perf_resnet50_N4.log"
|
||||
archiveArtifacts "perf_batched_gemm.log"
|
||||
archiveArtifacts "perf_grouped_gemm.log"
|
||||
archiveArtifacts "perf_grouped_conv_fwd.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_data.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_weight.log"
|
||||
archiveArtifacts "perf_gemm_bilinear.log"
|
||||
archiveArtifacts "perf_reduction.log"
|
||||
archiveArtifacts "perf_splitK_gemm.log"
|
||||
archiveArtifacts "perf_onnx_gemm.log"
|
||||
archiveArtifacts "perf_mixed_gemm.log"
|
||||
stash includes: "perf_**.log", name: "perf_log"
|
||||
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a"
|
||||
archiveArtifacts "perf_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_resnet50_N256_gfx90a.log"
|
||||
archiveArtifacts "perf_resnet50_N4_gfx90a.log"
|
||||
archiveArtifacts "perf_batched_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_grouped_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_grouped_conv_fwd_gfx90a.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_data_gfx90a.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_weight_gfx90a.log"
|
||||
archiveArtifacts "perf_gemm_bilinear_gfx90a.log"
|
||||
archiveArtifacts "perf_reduction_gfx90a.log"
|
||||
archiveArtifacts "perf_splitK_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_mixed_gemm_gfx90a.log"
|
||||
stash includes: "perf_**.log", name: "perf_log_gfx90a"
|
||||
}
|
||||
if (params.RUN_FULL_QA && arch == 2){
|
||||
// run full tests on gfx942
|
||||
echo "Run full performance tests"
|
||||
sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942"
|
||||
archiveArtifacts "perf_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_resnet50_N256_gfx942.log"
|
||||
archiveArtifacts "perf_resnet50_N4_gfx942.log"
|
||||
archiveArtifacts "perf_batched_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_grouped_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_grouped_conv_fwd_gfx942.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_data_gfx942.log"
|
||||
archiveArtifacts "perf_grouped_conv_bwd_weight_gfx942.log"
|
||||
archiveArtifacts "perf_gemm_bilinear_gfx942.log"
|
||||
archiveArtifacts "perf_reduction_gfx942.log"
|
||||
archiveArtifacts "perf_splitK_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_mixed_gemm_gfx942.log"
|
||||
stash includes: "perf_**.log", name: "perf_log_gfx942"
|
||||
}
|
||||
else if ( arch == 1 ){
|
||||
// run standard tests on gfx90a
|
||||
echo "Run performance tests"
|
||||
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}"
|
||||
archiveArtifacts "perf_gemm.log"
|
||||
archiveArtifacts "perf_onnx_gemm.log"
|
||||
archiveArtifacts "perf_resnet50_N256.log"
|
||||
archiveArtifacts "perf_resnet50_N4.log"
|
||||
stash includes: "perf_**.log", name: "perf_log"
|
||||
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx90a"
|
||||
archiveArtifacts "perf_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx90a.log"
|
||||
archiveArtifacts "perf_resnet50_N256_gfx90a.log"
|
||||
archiveArtifacts "perf_resnet50_N4_gfx90a.log"
|
||||
stash includes: "perf_**.log", name: "perf_log_gfx90a"
|
||||
}
|
||||
else if ( arch == 2 ){
|
||||
// run standard tests on gfx942
|
||||
echo "Run performance tests"
|
||||
sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx942"
|
||||
archiveArtifacts "perf_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_onnx_gemm_gfx942.log"
|
||||
archiveArtifacts "perf_resnet50_N256_gfx942.log"
|
||||
archiveArtifacts "perf_resnet50_N4_gfx942.log"
|
||||
stash includes: "perf_**.log", name: "perf_log_gfx942"
|
||||
}
|
||||
// disable performance tests on gfx1030 for now.
|
||||
//else if ( arch == 3){
|
||||
@@ -732,29 +753,64 @@ def process_results(Map conf=[:]){
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
try{
|
||||
unstash "perf_fmha_log_gfx942"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the FMHA performance logs for gfx942: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_fmha_log_gfx90a"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the FMHA performance logs: ${err.getMessage()}."
|
||||
echo "could not locate the FMHA performance logs for gfx90a: ${err.getMessage()}."
|
||||
}
|
||||
}
|
||||
if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){
|
||||
if (params.BUILD_INSTANCES_ONLY){
|
||||
// unstash deb packages
|
||||
unstash "packages"
|
||||
sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/"
|
||||
}
|
||||
else{
|
||||
// unstash perf files to master
|
||||
unstash "perf_log"
|
||||
try{
|
||||
unstash "perf_log_gfx90a"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx90a performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx942"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx942 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx950"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx950 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx908"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx908 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
unstash "perf_log_gfx11"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the gfx11 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
try{
|
||||
|
||||
unstash "perf_log_gfx12"
|
||||
}
|
||||
catch(Exception err){
|
||||
echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}."
|
||||
echo "could not locate the gfx12 performance logs: ${err.getMessage()}."
|
||||
}
|
||||
sh "./process_perf_data.sh"
|
||||
}
|
||||
// process the logs
|
||||
sh "./process_perf_data.sh"
|
||||
}
|
||||
}
|
||||
catch(e){
|
||||
@@ -819,13 +875,64 @@ def run_aiter_tests(Map conf=[:]){
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def run_pytorch_tests(Map conf=[:]){
|
||||
show_node_info()
|
||||
env.HSA_ENABLE_SDMA=0
|
||||
checkout scm
|
||||
//use the latest pytorch-nightly image
|
||||
def image = "rocm/composable_kernel:ck_pytorch"
|
||||
def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins"
|
||||
def variant = env.STAGE_NAME
|
||||
def retimage
|
||||
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
|
||||
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
|
||||
dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} "
|
||||
echo "Docker flags: ${dockerOpts}"
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
try
|
||||
{
|
||||
echo "Pulling image: ${image}"
|
||||
retimage = docker.image("${image}")
|
||||
withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) {
|
||||
retimage.pull()
|
||||
}
|
||||
}
|
||||
catch(Exception ex)
|
||||
{
|
||||
error "Unable to locate image: ${image}"
|
||||
}
|
||||
}
|
||||
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 45, unit: 'MINUTES'){
|
||||
try{
|
||||
sh "rocminfo"
|
||||
sh "python3 --version"
|
||||
sh "python3 /tmp/pytorch/tools/amd_build/build_amd.py"
|
||||
sh "USE_ROCM_CK_SDPA=1 PYTORCH_ROCM_ARCH=gfx942 python /tmp/pytorch/setup.py develop"
|
||||
}
|
||||
catch(e){
|
||||
echo "Throwing error exception while building Pytorch"
|
||||
echo 'Exception occurred: ' + e.toString()
|
||||
throw e
|
||||
}
|
||||
finally{
|
||||
echo "Finished building Pytorch"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : ""
|
||||
0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false
|
||||
0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX10=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false''' : ""
|
||||
|
||||
pipeline {
|
||||
agent none
|
||||
@@ -960,6 +1067,14 @@ pipeline {
|
||||
name: "RUN_ALL_UNIT_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Run all unit tests (default: OFF)")
|
||||
booleanParam(
|
||||
name: "RUN_PYTORCH_TESTS",
|
||||
defaultValue: false,
|
||||
description: "Try building PYTORCH with latest CK develop branch (default: OFF)")
|
||||
string(
|
||||
name: 'ck_pytorch_branch',
|
||||
defaultValue: 'develop',
|
||||
description: 'Specify which branch of CK to test with Pytorch (default: develop)')
|
||||
booleanParam(
|
||||
name: "RUN_AITER_TESTS",
|
||||
defaultValue: false,
|
||||
@@ -1051,6 +1166,24 @@ pipeline {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run Pytorch Tests")
|
||||
{
|
||||
parallel
|
||||
{
|
||||
stage("Run Pytorch Tests on gfx942")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_PYTORCH_TESTS.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("gfx942")}
|
||||
steps{
|
||||
run_pytorch_tests()
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
stage("Run AITER Tests")
|
||||
{
|
||||
@@ -1107,11 +1240,16 @@ pipeline {
|
||||
agent{ label rocmnode("gfx90a")}
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ cd test_data && \
|
||||
./generate_test_dataset.sh && \
|
||||
cd ../script && \
|
||||
execute_args = """ cd ../build && \
|
||||
../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make -j64 test_grouped_convnd_fwd_dataset_xdl && \
|
||||
cd ../test_data && \
|
||||
# Dataset generation modes:
|
||||
# - small: ~60 test cases (minimal, quick testing - 3 models, 2 batch sizes, 2 image sizes)
|
||||
# - half: ~300 test cases (moderate coverage - 16 models, 3 batch sizes, 5 image sizes), ~ 17 hours testing time
|
||||
# - full: ~600 test cases (comprehensive - 16 models, 5 batch sizes, 9 image sizes), ~ 40 hours testing time
|
||||
./generate_test_dataset.sh half && \
|
||||
cd ../build && \
|
||||
./bin/test_grouped_convnd_fwd_dataset_xdl"""
|
||||
}
|
||||
steps{
|
||||
@@ -1306,6 +1444,7 @@ pipeline {
|
||||
def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_rhel8_rocm6.3"
|
||||
setup_args = """ -DGPU_TARGETS="gfx942" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " \
|
||||
-DCK_CXX_STANDARD="17" \
|
||||
-DCK_USE_ALTERNATIVE_PYTHON=/opt/Python-3.8.13/bin/python3.8 """
|
||||
execute_args = " "
|
||||
}
|
||||
@@ -1440,7 +1579,7 @@ pipeline {
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """
|
||||
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0")
|
||||
}
|
||||
cleanWs()
|
||||
}
|
||||
@@ -1517,7 +1656,7 @@ pipeline {
|
||||
stage("Process results"){
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean() || params.RUN_CK_TILE_FMHA_TESTS.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() }
|
||||
}
|
||||
agent { label 'mici' }
|
||||
steps{
|
||||
|
||||
@@ -19,7 +19,6 @@ Getting started
|
||||
build the library. You can also find some of this information in the
|
||||
`README file <https://github.com/ROCm/composable_kernel/blob/develop/README.md>`_
|
||||
on the project's GitHub page.
|
||||
#. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_ provides a deeper understanding of the CK library and showcases its performance capabilities.
|
||||
<https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_
|
||||
from the AMD Community portal. It offers a deeper understanding of the library's objectives and showcases its performance capabilities.
|
||||
#. **General information:** For broader information about AMD products, consider exploring the
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
@@ -53,6 +54,17 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
try
|
||||
{
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -144,6 +144,28 @@ list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
|
||||
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
|
||||
|
||||
# add fmha_fwd_v3 example
|
||||
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
|
||||
|
||||
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
|
||||
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
|
||||
)
|
||||
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
|
||||
fmha_fwd_v3.cpp
|
||||
${FMHA_FWD_V3_INSTANCES}
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
|
||||
-fgpu-flush-denormals-to-zero
|
||||
-Wno-undefined-func-template
|
||||
--save-temps
|
||||
)
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
|
||||
@@ -7,7 +7,7 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_fmha_fwd -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_fmha_fwd`
|
||||
|
||||
@@ -110,9 +110,9 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@@ -136,10 +136,10 @@ float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template <>
|
||||
@@ -148,9 +148,9 @@ void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_co
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
@@ -425,10 +425,10 @@ float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template <>
|
||||
@@ -436,9 +436,9 @@ void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_co
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
@@ -530,10 +530,10 @@ float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_confi
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template <>
|
||||
@@ -542,9 +542,9 @@ void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::strea
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
|
||||
@@ -110,9 +110,9 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
@@ -385,7 +385,7 @@ class FmhaFwdApiPool:
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][(hdim, hdim_v)]
|
||||
traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
|
||||
@@ -60,9 +60,9 @@ float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fw
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@@ -108,9 +108,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel;
|
||||
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
}};
|
||||
}}
|
||||
@@ -208,9 +208,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel;
|
||||
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
}};
|
||||
}}
|
||||
|
||||
@@ -109,9 +109,9 @@ float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
492
example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp
Normal file
492
example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp
Normal file
@@ -0,0 +1,492 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include <ck_tile/core/numeric/math.hpp>
|
||||
#include <ck_tile/core/utility/functional.hpp>
|
||||
#include <ck_tile/host/arg_parser.hpp>
|
||||
#include <ck_tile/host/device_memory.hpp>
|
||||
#include <ck_tile/host/fill.hpp>
|
||||
#include <ck_tile/host/check_err.hpp>
|
||||
#include <ck_tile/host/host_tensor.hpp>
|
||||
#include <ck_tile/host/reference/reference_batched_gemm.hpp>
|
||||
#include <ck_tile/host/reference/reference_batched_masking.hpp>
|
||||
#include <ck_tile/host/reference/reference_batched_softmax.hpp>
|
||||
|
||||
#include "fmha_fwd.hpp"
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s", "3328", "seqlen_q")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q & k")
|
||||
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("iperm",
|
||||
"0",
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "0", "permute output")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
"'xt:window_size', xformer style masking from top-left, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
|
||||
"now)")
|
||||
.insert("v", "1", "0:no verify, 1:verify")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "30", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_pair(result, arg_parser);
|
||||
}
|
||||
|
||||
enum class TensorLayout
|
||||
{
|
||||
bhsd,
|
||||
bshd,
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, TensorLayout layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case TensorLayout::bhsd: return stream << "bhsd";
|
||||
case TensorLayout::bshd: return stream << "bshd";
|
||||
default: return stream << "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
struct Problem
|
||||
{
|
||||
explicit Problem(const ck_tile::ArgParser& args)
|
||||
{
|
||||
data_type = args.get_str("prec") == "fp16"
|
||||
? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16
|
||||
: ck_tile::fmha_fwd_v3_args::data_type_enum::bf16;
|
||||
batch = args.get_int("b");
|
||||
seqlen_q = args.get_int("s");
|
||||
seqlen_k = args.get_int("s_k");
|
||||
if(seqlen_k < 0)
|
||||
{
|
||||
seqlen_k = seqlen_q;
|
||||
}
|
||||
nhead_q = args.get_int("h");
|
||||
nhead_kv = args.get_int("h_k");
|
||||
if(nhead_kv < 0)
|
||||
{
|
||||
nhead_kv = nhead_q;
|
||||
}
|
||||
hdim = args.get_int("d");
|
||||
softmax_scale = args.get_float("scale_s");
|
||||
if(softmax_scale == .0f)
|
||||
softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim));
|
||||
mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k);
|
||||
|
||||
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_query_shape() const
|
||||
{
|
||||
if(input_layout == TensorLayout::bhsd)
|
||||
{
|
||||
return {batch, nhead_q, seqlen_q, hdim};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {batch, seqlen_q, nhead_q, hdim};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_key_shape() const
|
||||
{
|
||||
if(input_layout == TensorLayout::bhsd)
|
||||
{
|
||||
return {batch, nhead_kv, seqlen_k, hdim};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {batch, seqlen_k, nhead_kv, hdim};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_value_shape() const
|
||||
{
|
||||
if(input_layout == TensorLayout::bhsd)
|
||||
{
|
||||
return {batch, nhead_kv, seqlen_k, hdim};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {batch, seqlen_k, nhead_kv, hdim};
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_output_shape() const
|
||||
{
|
||||
if(output_layout == TensorLayout::bhsd)
|
||||
{
|
||||
return {batch, nhead_q, seqlen_q, hdim};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {batch, seqlen_q, nhead_q, hdim};
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::fmha_fwd_v3_args::data_type_enum data_type;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_kv;
|
||||
ck_tile::index_t hdim;
|
||||
float softmax_scale;
|
||||
mask_info mask;
|
||||
TensorLayout input_layout;
|
||||
TensorLayout output_layout;
|
||||
};
|
||||
|
||||
struct RunConfig
|
||||
{
|
||||
explicit RunConfig(const ck_tile::ArgParser& args)
|
||||
{
|
||||
seed = args.get_uint32("seed");
|
||||
if(*seed == 0)
|
||||
{
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
kernel_warmup = args.get_int("warmup");
|
||||
kernel_repeat = args.get_int("repeat");
|
||||
verify = args.get_bool("v");
|
||||
}
|
||||
|
||||
std::optional<uint32_t> seed;
|
||||
int kernel_warmup;
|
||||
int kernel_repeat;
|
||||
bool verify;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
auto generate_qkv(const Problem& problem,
|
||||
[[maybe_unused]] std::optional<uint32_t> seed = std::nullopt)
|
||||
-> std::tuple<ck_tile::HostTensor<DataType>,
|
||||
ck_tile::HostTensor<DataType>,
|
||||
ck_tile::HostTensor<DataType>>
|
||||
{
|
||||
ck_tile::HostTensor<DataType> q(problem.get_query_shape());
|
||||
ck_tile::HostTensor<DataType> k(problem.get_key_shape());
|
||||
ck_tile::HostTensor<DataType> v(problem.get_value_shape());
|
||||
|
||||
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(q);
|
||||
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(k);
|
||||
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(v);
|
||||
|
||||
return std::make_tuple(q, k, v);
|
||||
}
|
||||
|
||||
namespace host {
|
||||
template <typename AccDataType,
|
||||
typename PDataType,
|
||||
typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename QElementOp,
|
||||
typename KElementOp,
|
||||
typename VElementOp,
|
||||
typename SAccElementOp>
|
||||
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
const ck_tile::HostTensor<KDataType>& k_bshd,
|
||||
const ck_tile::HostTensor<VDataType>& v_bshd,
|
||||
const mask_info& mask,
|
||||
ck_tile::HostTensor<ODataType>& o_bshd,
|
||||
const QElementOp& q_element_op = {},
|
||||
const KElementOp& k_element_op = {},
|
||||
const VElementOp& v_element_op = {},
|
||||
const SAccElementOp& s_acc_element_op = {})
|
||||
{
|
||||
const int batch_size = q_bshd.mDesc.get_lengths()[0];
|
||||
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
|
||||
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
|
||||
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
|
||||
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
|
||||
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
|
||||
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
|
||||
|
||||
const int nr = nhead_q / nhead_kv;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
|
||||
// do computation for each batch
|
||||
for(int b = 0; b < batch_size; ++b)
|
||||
{
|
||||
// copy per-batch data from input tensors
|
||||
// clang-format off
|
||||
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); });
|
||||
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); });
|
||||
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
|
||||
// clang-format on
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
|
||||
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
|
||||
}
|
||||
else if(mask.type == mask_enum::window_generic)
|
||||
{
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left, mask.right, seqlen_q, seqlen_kv));
|
||||
}
|
||||
else
|
||||
{
|
||||
// if left window size is negative, means causal
|
||||
// else means generic (for current batch)
|
||||
if(mask.left < 0)
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
else
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
|
||||
mask.left,
|
||||
mask.right,
|
||||
seqlen_q,
|
||||
seqlen_kv,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
|
||||
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
|
||||
s_host_ref, p_host_ref, ck_tile::identity{});
|
||||
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
|
||||
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
|
||||
|
||||
// copy resulting per-batch data to the output tensor
|
||||
o_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
|
||||
}
|
||||
}
|
||||
} // namespace host
|
||||
|
||||
template <typename DataType>
|
||||
bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
{
|
||||
auto [q, k, v] = generate_qkv<DataType>(problem, run_config.seed);
|
||||
|
||||
ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes());
|
||||
/// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v
|
||||
ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q.data());
|
||||
k_buf.ToDevice(k.data());
|
||||
v_buf.ToDevice(v.data());
|
||||
|
||||
ck_tile::fmha_fwd_v3_args args;
|
||||
|
||||
args.data_type = problem.data_type;
|
||||
args.batch = problem.batch;
|
||||
args.seqlen_q = problem.seqlen_q;
|
||||
args.seqlen_k = problem.seqlen_k;
|
||||
args.nhead_q = problem.nhead_q;
|
||||
args.nhead_kv = problem.nhead_kv;
|
||||
args.hdim_qk = problem.hdim;
|
||||
args.hdim_v = problem.hdim;
|
||||
args.softmax_scale = problem.softmax_scale;
|
||||
|
||||
args.window_size_left = problem.mask.left;
|
||||
args.window_size_right = problem.mask.right;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(problem.mask.type);
|
||||
|
||||
// bshd: (batch, seqlen_q, nhead_q, hdim)
|
||||
// bhsd: (batch, nhead_q, seqlen_q, hdim)
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.stride_q =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
|
||||
args.nhead_stride_q =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim;
|
||||
args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim;
|
||||
|
||||
// bshd: (batch, seqlen_k, nhead_kv, hdim)
|
||||
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.stride_k =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
|
||||
args.nhead_stride_k =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
|
||||
args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim;
|
||||
|
||||
// bshd: (batch, seqlen_k, nhead_kv, hdim)
|
||||
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
args.stride_v =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
|
||||
args.nhead_stride_v =
|
||||
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
|
||||
args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim;
|
||||
|
||||
// bshd: (batch, seqlen_q, nhead_q, hdim)
|
||||
// bhsd: (batch, nhead_q, seqlen_q, hdim)
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
args.stride_o =
|
||||
problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
|
||||
args.nhead_stride_o = problem.output_layout == TensorLayout::bshd
|
||||
? problem.hdim
|
||||
: problem.seqlen_q * problem.hdim;
|
||||
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/*log_level=*/0,
|
||||
run_config.kernel_warmup,
|
||||
run_config.kernel_repeat};
|
||||
|
||||
auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::size_t flop = [&] {
|
||||
if(problem.mask.type == mask_enum::no_mask)
|
||||
{
|
||||
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
|
||||
problem.hdim;
|
||||
}
|
||||
else
|
||||
{
|
||||
/// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2.
|
||||
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
|
||||
problem.hdim;
|
||||
}
|
||||
}();
|
||||
float tflops = static_cast<float>(flop) / 1.e9 / time;
|
||||
|
||||
std::cout << "[" << problem.data_type << "|";
|
||||
if(problem.input_layout == problem.output_layout)
|
||||
{
|
||||
std::cout << problem.input_layout;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << problem.input_layout << "-" << problem.output_layout;
|
||||
}
|
||||
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
|
||||
<< ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
|
||||
<< ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
|
||||
<< ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
|
||||
<< " TFlops" << std::endl;
|
||||
|
||||
if(!run_config.verify)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// transpose tensor descriptors from bhsd to bshd if necessary
|
||||
if(problem.input_layout != TensorLayout::bshd)
|
||||
{
|
||||
q = q.transpose({0, 2, 1, 3});
|
||||
k = k.transpose({0, 2, 1, 3});
|
||||
v = v.transpose({0, 2, 1, 3});
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
|
||||
if(problem.output_layout != TensorLayout::bshd)
|
||||
{
|
||||
o_ref = o_ref.transpose({0, 2, 1, 3});
|
||||
}
|
||||
|
||||
host::fmha_fwd<float, DataType>(q,
|
||||
k,
|
||||
v,
|
||||
problem.mask,
|
||||
o_ref,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.softmax_scale});
|
||||
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
const auto [rtol, atol] = [&] {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
return std::make_tuple(1e-3, 1e-3);
|
||||
else
|
||||
return std::make_tuple(1e-2, 1e-2);
|
||||
}();
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [parse_result, args] = parse_cmd_args(argc, argv);
|
||||
if(!parse_result)
|
||||
{
|
||||
std::cerr << "failed to parse command line arguments" << std::endl;
|
||||
}
|
||||
|
||||
Problem problem(args);
|
||||
RunConfig run_config(args);
|
||||
|
||||
const auto run = [&] {
|
||||
if(problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16)
|
||||
{
|
||||
return run_impl<ck_tile::fp16_t>(problem, run_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_impl<ck_tile::bf16_t>(problem, run_config);
|
||||
}
|
||||
};
|
||||
|
||||
return !run();
|
||||
}
|
||||
@@ -809,20 +809,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::stream_config stream_config_v{
|
||||
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, "
|
||||
"bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n",
|
||||
fmha_traits.hdim_q,
|
||||
fmha_traits.hdim_v,
|
||||
fmha_traits.data_type.c_str(),
|
||||
fmha_traits.is_group_mode,
|
||||
static_cast<int>(fmha_traits.mask_type),
|
||||
static_cast<int>(fmha_traits.bias_type),
|
||||
fmha_traits.has_dbias,
|
||||
fmha_traits.has_dropout,
|
||||
fmha_traits.is_store_randval,
|
||||
fmha_traits.is_deterministic);
|
||||
fflush(stdout);
|
||||
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
|
||||
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
|
||||
60
example/ck_tile/01_fmha/fmha_fwd_v3.cpp
Normal file
60
example/ck_tile/01_fmha/fmha_fwd_v3.cpp
Normal file
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type)
|
||||
{
|
||||
switch(data_type)
|
||||
{
|
||||
case fmha_fwd_v3_args::data_type_enum::fp16: return stream << "fp16";
|
||||
case fmha_fwd_v3_args::data_type_enum::bf16: return stream << "bf16";
|
||||
default: return stream << "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config)
|
||||
{
|
||||
if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16)
|
||||
{
|
||||
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
|
||||
{
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
|
||||
|
||||
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
else
|
||||
{
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
|
||||
|
||||
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
}
|
||||
else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16)
|
||||
{
|
||||
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
|
||||
{
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
|
||||
|
||||
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
else
|
||||
{
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
|
||||
|
||||
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
67
example/ck_tile/01_fmha/fmha_fwd_v3.hpp
Normal file
67
example/ck_tile/01_fmha/fmha_fwd_v3.hpp
Normal file
@@ -0,0 +1,67 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct fmha_fwd_v3_args
|
||||
{
|
||||
enum class data_type_enum
|
||||
{
|
||||
fp16,
|
||||
bf16
|
||||
};
|
||||
|
||||
data_type_enum data_type;
|
||||
// bool is_varlen;
|
||||
|
||||
index_t batch;
|
||||
index_t seqlen_q;
|
||||
index_t seqlen_k;
|
||||
index_t nhead_q;
|
||||
index_t nhead_kv;
|
||||
index_t hdim_qk;
|
||||
index_t hdim_v;
|
||||
|
||||
float softmax_scale;
|
||||
|
||||
index_t window_size_left;
|
||||
index_t window_size_right;
|
||||
index_t mask_type;
|
||||
|
||||
const void* q_ptr;
|
||||
index_t stride_q;
|
||||
index_t nhead_stride_q;
|
||||
index_t batch_stride_q;
|
||||
|
||||
const void* k_ptr;
|
||||
index_t stride_k;
|
||||
index_t nhead_stride_k;
|
||||
index_t batch_stride_k;
|
||||
|
||||
const void* v_ptr;
|
||||
index_t stride_v;
|
||||
index_t nhead_stride_v;
|
||||
index_t batch_stride_v;
|
||||
|
||||
void* o_ptr;
|
||||
index_t stride_o;
|
||||
index_t nhead_stride_o;
|
||||
index_t batch_stride_o;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);
|
||||
|
||||
// return value:
|
||||
// first = whether the kernel was launched (true = launched, false = skipped)
|
||||
// second = elapsed time (ms) of the kernel launch, valid only if first == true
|
||||
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config);
|
||||
|
||||
} // namespace ck_tile
|
||||
159
example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp
Normal file
159
example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp
Normal file
@@ -0,0 +1,159 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
|
||||
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
|
||||
template <> \
|
||||
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch<kernel_traits>( \
|
||||
const fmha_fwd_v3_args& args, const stream_config& config) \
|
||||
{ \
|
||||
return std::make_pair(true, \
|
||||
fmha_fwd_v3_kernel_launch<kernel_traits::kernel>(args, config)); \
|
||||
}
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <fmha_fwd_v3_args::data_type_enum DataType>
|
||||
struct fmha_fwd_v3_problem_traits;
|
||||
|
||||
template <>
|
||||
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::fp16>
|
||||
{
|
||||
using qkvp_dtype = ck_tile::half_t;
|
||||
using acc_dtype = float;
|
||||
using o_dtype = ck_tile::half_t;
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
|
||||
{
|
||||
using qkvp_dtype = ck_tile::bf16_t;
|
||||
using acc_dtype = float;
|
||||
using o_dtype = ck_tile::bf16_t;
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
template <fmha_fwd_v3_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
|
||||
struct fmha_fwd_v3_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_variable_seqlen = IsVariableSeqlen;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
// M0 N0 K0 N1 K1
|
||||
using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>;
|
||||
using fmha_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
using fmha_block_warps = sequence<8, 1, 1>;
|
||||
|
||||
using fmha_shape = TileFmhaShape<fmha_block_tile,
|
||||
fmha_block_warps,
|
||||
fmha_warp_gemm_shape,
|
||||
fmha_block_warps,
|
||||
fmha_warp_gemm_shape,
|
||||
true // IsVLayoutRowMajor
|
||||
>;
|
||||
|
||||
using fmha_traits = TileFmhaFwdV3Traits<true, // kPadSeqLenQ
|
||||
true, // kPadSeqLenK
|
||||
false, // kPadHeadDimQ
|
||||
false, // kPadHeadDimV
|
||||
false, // kStoreLSE
|
||||
-1 // kBlockPerCu
|
||||
>;
|
||||
|
||||
using fmha_mask = SimplifiedGenericAttentionMask<IsMasking>;
|
||||
|
||||
using fmha_pipeline_problem =
|
||||
BlockFmhaFwdV3PipelineProblem<typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::lse_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
|
||||
fmha_shape,
|
||||
IsVariableSeqlen,
|
||||
fmha_mask,
|
||||
fmha_traits>;
|
||||
|
||||
using fmha_pipeline = BlockFmhaFwdV3Pipeline<fmha_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
|
||||
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
|
||||
true, // kPadM
|
||||
true, // kPadM
|
||||
true // UseRawStore
|
||||
>>;
|
||||
|
||||
using kernel = FmhaFwdV3Kernel<fmha_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config)
|
||||
{
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
nullptr, // lse_ptr
|
||||
args.o_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_qk,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_kv,
|
||||
args.softmax_scale,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_v,
|
||||
args.stride_o,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.nhead_stride_v,
|
||||
0, // nhead_stride_lse
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
0, // batch_stride_lse
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
|
||||
|
||||
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
// return value:
|
||||
// first = whether the kernel was launched (true = launched, false = skipped)
|
||||
// second = elapsed time (ms) of the kernel launch, valid only if first == true
|
||||
template <typename KernelTraits>
|
||||
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args,
|
||||
const stream_config& config);
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
|
||||
|
||||
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
|
||||
|
||||
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
|
||||
|
||||
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "fmha_fwd_v3_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
|
||||
|
||||
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
31
example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh
Executable file
31
example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh
Executable file
@@ -0,0 +1,31 @@
|
||||
#!/bin/sh
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)"
|
||||
VALID=0
|
||||
|
||||
for causal in 0 1 ; do
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for hdim in 128 ; do
|
||||
for perm in 0 ; do
|
||||
|
||||
if [ $causal -eq 0 ]; then
|
||||
mask=0
|
||||
else
|
||||
mask=b:-1,0
|
||||
fi
|
||||
|
||||
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -9,6 +9,8 @@
|
||||
# host name : $hostname
|
||||
# gpu architecture: e.g., gfx90a, or gfx942, etc.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
#get the command line arguments:
|
||||
export env_type=$1
|
||||
echo 'Environment type: ' $env_type
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#!/bin/sh
|
||||
#!/bin/bash
|
||||
# TODO: run this script from CK root or build directory
|
||||
set -euo pipefail
|
||||
|
||||
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
|
||||
KNAME=1
|
||||
|
||||
@@ -17,12 +19,12 @@ for dbias in 0 ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for deterministic in 0 ; do
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#!/bin/bash
|
||||
# TODO: run this script from CK root or build directory
|
||||
set -euo pipefail
|
||||
|
||||
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
|
||||
KNAME=1
|
||||
|
||||
@@ -51,19 +53,18 @@ run_fp16_bf16_tests() {
|
||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
||||
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
done ;
|
||||
}
|
||||
|
||||
run_fp8_tests() {
|
||||
|
||||
@@ -42,7 +42,7 @@ return hidden_states, per_token_scale
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_layernorm2d_fwd -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
|
||||
|
||||
@@ -235,7 +235,7 @@ float layernorm2d_fwd_(const S& s, A a)
|
||||
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
@@ -243,7 +243,7 @@ float layernorm2d_fwd_(const S& s, A a)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_basic -j
|
||||
# The memory bound pipeline on the gemm calculation
|
||||
|
||||
@@ -26,6 +26,15 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#else
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
@@ -33,6 +42,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#endif
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -65,7 +75,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
@@ -81,8 +90,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -100,10 +109,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
@@ -208,7 +208,6 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
@@ -232,7 +231,7 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -279,15 +278,13 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
@@ -373,7 +370,7 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
|
||||
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
|
||||
4
example/ck_tile/03_gemm/gemm_utils.hpp
Executable file → Normal file
4
example/ck_tile/03_gemm/gemm_utils.hpp
Executable file → Normal file
@@ -172,6 +172,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
{
|
||||
@@ -192,6 +193,7 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
@@ -484,7 +486,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
@@ -126,7 +125,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -172,15 +171,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -103,7 +103,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
@@ -127,7 +126,7 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -173,15 +172,13 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
@@ -338,7 +335,11 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
|
||||
#else
|
||||
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
|
||||
@@ -7,7 +7,7 @@ This folder contains example for Image to Column using ck_tile tile-programming
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_img2col -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_img2col`
|
||||
|
||||
@@ -55,13 +55,12 @@ float image_to_column(const image_to_column_traits& traits,
|
||||
args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1],
|
||||
args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C,
|
||||
args.G);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 2;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream_conf,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
stream_conf, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -94,18 +94,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
throw std::runtime_error("Wrong! Arguments not supported!\n");
|
||||
}
|
||||
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims));
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
input_shape,
|
||||
input_strides,
|
||||
kept_dim,
|
||||
reduce_dims));
|
||||
|
||||
std::size_t num_btype = sizeof(XDataType) * N * C * H * W + sizeof(YDataType) * N * C;
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ args:
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_permute -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_permute`
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -115,11 +115,12 @@ struct matrix_core_swizzle_kernel
|
||||
|
||||
__host__ void operator()(const ck_tile::stream_config& s) const
|
||||
{
|
||||
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
|
||||
ck_tile::kentry<1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
|
||||
}
|
||||
|
||||
struct kernel
|
||||
{
|
||||
static constexpr int kBlockSize = BLOCK_SIZE;
|
||||
__device__ static constexpr auto get_src_dist()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
@@ -53,11 +53,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
@@ -69,11 +69,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
@@ -85,11 +85,11 @@ float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ This folder contains example for topk-softmax kernel using ck_tile tile-programm
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_topk_softmax -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_topk_softmax`
|
||||
|
||||
@@ -13,11 +13,11 @@
|
||||
\
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
\
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
constexpr dim3 blocks = kernel::BlockSize(); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(); \
|
||||
\
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
\
|
||||
return ave_time;
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ This folder contains example for Rmsnorm2D forward using ck_tile tile-programmin
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_rmsnorm2d_fwd -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_rmsnorm2d_fwd`
|
||||
|
||||
@@ -138,12 +138,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto kargs = Kernel::MakeKargs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
bool pass = true;
|
||||
|
||||
|
||||
@@ -249,7 +249,7 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Epilogue>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
@@ -257,7 +257,7 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
@@ -6,7 +6,7 @@ This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_add_rmsnorm2d_rdquant_fwd -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd`
|
||||
|
||||
@@ -136,12 +136,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto kargs = Kernel::MakeKargs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
bool pass = true;
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
|
||||
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
@@ -66,5 +66,5 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ This folder contains example for smoothquant using ck_tile tile-programming impl
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_smoothquant -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_smoothquant`
|
||||
|
||||
@@ -126,12 +126,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto kargs = Kernel::MakeKargs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
auto s = ck_tile::stream_config{nullptr, true, 1, warmup, repeat};
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
bool pass = true;
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ float smoothquant_(const S& s, A a)
|
||||
using Kernel = ck_tile::Smoothquant<Pipeline>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
@@ -58,5 +58,5 @@ float smoothquant_(const S& s, A a)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ This folder contains example for moe-sorting kernel using ck_tile tile-programmi
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_sorting -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_moe_sorting`
|
||||
|
||||
@@ -209,7 +209,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
@@ -227,7 +227,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
@@ -283,7 +283,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_size = kernel::GetSmemSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
@@ -334,15 +334,15 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
} \
|
||||
}
|
||||
|
||||
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
|
||||
[&]() { \
|
||||
using problem_ = \
|
||||
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
|
||||
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
|
||||
[&]() { \
|
||||
using problem_ = \
|
||||
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
|
||||
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
|
||||
@@ -9,7 +9,7 @@ Unlike standard smoothquant op, the input scale is from different expert `[exper
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_moe_smoothquant -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_moe_smoothquant`
|
||||
|
||||
@@ -53,7 +53,7 @@ float moe_smoothquant_(const S& s, A a)
|
||||
using Kernel = ck_tile::MoeSmoothquant<Pipeline>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
@@ -61,5 +61,5 @@ float moe_smoothquant_(const S& s, A a)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
|
||||
|
||||
const dim3 grids = f_kernel::GridSize(a);
|
||||
constexpr dim3 blocks = f_kernel::BlockSize();
|
||||
const dim3 blocks = f_kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
static int printed = 0;
|
||||
@@ -66,5 +66,5 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
@@ -213,7 +213,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
@@ -231,7 +231,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
|
||||
#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
|
||||
@@ -287,7 +287,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_size = kernel::GetSmemSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
|
||||
}()
|
||||
|
||||
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
|
||||
|
||||
@@ -7,7 +7,7 @@ This folder contains example for batched GEMM using ck_tile tile-programming imp
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_batched_gemm -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_batched_gemm`
|
||||
|
||||
@@ -142,7 +142,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
@@ -156,8 +155,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -176,7 +175,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
|
||||
@@ -148,7 +148,7 @@ All the necessary parameters are set, the tiling is computed, the GEMM pipeline
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the gemm calculation
|
||||
make tile_example_grouped_gemm -j
|
||||
```
|
||||
|
||||
@@ -29,10 +29,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
@@ -44,7 +40,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
@@ -53,8 +48,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
@@ -82,7 +75,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
@@ -92,9 +84,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
@@ -105,7 +97,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
|
||||
@@ -7,7 +7,7 @@ This folder contains example for FLATMM using ck_tile tile-programming implement
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the flatmm calculation
|
||||
make tile_example_flatmm_basic -j
|
||||
```
|
||||
|
||||
@@ -101,7 +101,6 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
@@ -119,8 +118,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -171,15 +170,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -40,9 +40,11 @@ template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2)
|
||||
: (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4);
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
@@ -213,6 +215,16 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
}
|
||||
else if(init_method == 3)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
}
|
||||
else if(init_method == 4)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_host.SetZero();
|
||||
|
||||
@@ -8,7 +8,7 @@ This folder contains example for Multiple D GEMM using ck_tile tile-programming
|
||||
mkdir build && cd build
|
||||
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
|
||||
leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
#The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_multi_d_fp16 -j
|
||||
```
|
||||
|
||||
@@ -146,7 +146,6 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
@@ -160,8 +159,8 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -176,7 +175,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
|
||||
@@ -6,3 +6,6 @@ target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMP
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_grouped_convolution_bwd_data_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_bwd_data_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<1>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<2>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_with_layouts<ck_tile::number<3>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout!");
|
||||
}
|
||||
}
|
||||
|
||||
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_data_example_prec_type<ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_data_example(argc, argv); }
|
||||
@@ -78,7 +78,6 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
@@ -98,8 +97,8 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -123,7 +122,7 @@ float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
Kernel::Preprocess(kargs, s),
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -77,7 +77,6 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
@@ -97,8 +96,8 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -120,7 +119,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_til
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_bwd_data<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
int run_grouped_conv_bwd_data_example_with_layouts(
|
||||
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
std::vector<ck_tile::index_t> filter_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> image_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> strides;
|
||||
std::vector<ck_tile::index_t> dilations;
|
||||
std::vector<ck_tile::index_t> lpads;
|
||||
std::vector<ck_tile::index_t> rpads;
|
||||
|
||||
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads,
|
||||
arg_parser);
|
||||
|
||||
ck_tile::conv::ConvParam conv_param{num_dim_sp,
|
||||
arg_parser.get_int("g"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("c"),
|
||||
filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{-1.f, 1.f}(weight);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{-1.f, 1.f}(output);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
|
||||
ck_tile::FillMonotonicSeq<OutDataType>{}(output);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(output);
|
||||
}
|
||||
else
|
||||
{
|
||||
weight.SetZero();
|
||||
output.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
|
||||
|
||||
input_dev_buf.SetZero();
|
||||
weight_dev_buf.ToDevice(weight.data());
|
||||
output_dev_buf.ToDevice(output.data());
|
||||
|
||||
ck_tile::GroupedConvBwdDataHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
|
||||
std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_bwd_data<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
|
||||
input_dev_buf.FromDevice(input.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<InDataType> input_host_ref(in_g_n_c_wis_desc);
|
||||
input_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
|
||||
input_host_ref,
|
||||
weight,
|
||||
output,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_);
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
|
||||
GemmK, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(input,
|
||||
input_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
throw std::runtime_error("Unsupported gpu verification !!!");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -167,17 +167,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
// 4. Run the kernel
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(N, 1), // Input Stride
|
||||
ck_tile::make_tuple(N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(N, 1), // Input Stride
|
||||
ck_tile::make_tuple(N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
|
||||
|
||||
std::cout << "Average time: " << ave_time << " ms" << std::endl;
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// Run the kernel
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
|
||||
@@ -112,17 +112,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
// 4. Run the kernel
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0, // Shared memory
|
||||
op_lengths, // Logical dimensions for the operation (M, N)
|
||||
input_strides, // Strides for input tensor(s)
|
||||
output_strides, // Strides for output tensor (N, M)
|
||||
input_tensors,
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0, // Shared memory
|
||||
op_lengths, // Logical dimensions for the operation (M, N)
|
||||
input_strides, // Strides for input tensor(s)
|
||||
output_strides, // Strides for output tensor (N, M)
|
||||
input_tensors,
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
|
||||
|
||||
std::cout << "Average time: " << ave_time << " ms" << std::endl;
|
||||
|
||||
|
||||
@@ -99,17 +99,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
// 4. Run the kernel
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(N, 1), // Input Stride
|
||||
ck_tile::make_tuple(N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockPerCu>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(N, 1), // Input Stride
|
||||
ck_tile::make_tuple(N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
|
||||
|
||||
std::cout << "Average time: " << ave_time << " ms" << std::endl;
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ This folder contains example for batched Transpose using ck_tile tile-programmin
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Make the transpose executable
|
||||
make tile_example_batched_transpose -j
|
||||
```
|
||||
|
||||
@@ -74,8 +74,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
|
||||
|
||||
auto kargs = kernel::MakeKargs(a);
|
||||
|
||||
const dim3 grids = kernel::GridSize(a);
|
||||
constexpr dim3 blocks = kernel::BlockSize();
|
||||
const dim3 grids = kernel::GridSize(a);
|
||||
const dim3 blocks = kernel::BlockSize();
|
||||
|
||||
printf("Pipeline: %d\n", Config::kPipelineId);
|
||||
printf("Grid: x=%u y=%u z=%u\n", grids.x, grids.y, grids.z);
|
||||
@@ -96,8 +96,8 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con
|
||||
|
||||
printf("Launching Kernel...\n");
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
printf("Kernel finished...\n");
|
||||
|
||||
|
||||
@@ -8,9 +8,8 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
add_executable(tile_example_gemm_aquant_basic EXCLUDE_FROM_ALL gemm_aquant_basic.cpp)
|
||||
target_compile_options(tile_example_gemm_aquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_gemm_aquant_preshuffle EXCLUDE_FROM_ALL gemm_aquant_preshuffle.cpp)
|
||||
target_compile_options(tile_example_gemm_aquant_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_executable(tile_example_gemm_bquant_basic EXCLUDE_FROM_ALL gemm_bquant_basic.cpp)
|
||||
target_compile_options(tile_example_gemm_bquant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -7,9 +7,10 @@ This folder contains example for Block Scale GEMM using ck_tile tile-programming
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The aquant pipeline method on the gemm calculation
|
||||
make tile_example_gemm_aquant_basic -j
|
||||
make tile_example_gemm_bquant_basic -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_aquant_basic`
|
||||
|
||||
|
||||
@@ -8,11 +8,10 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -21,29 +20,26 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
bool Preshuffle = false>
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 256;
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 4;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -52,8 +48,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
@@ -68,7 +69,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
constexpr bool transposed_warp_gemm = false;
|
||||
constexpr bool transposed_warp_gemm = true;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
@@ -82,6 +83,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
transposed_warp_gemm,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
@@ -96,7 +98,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
@@ -111,8 +112,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
@@ -136,7 +137,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
@@ -187,13 +188,14 @@ int run_gemm_example(int argc, char* argv[])
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
@@ -201,32 +203,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
@@ -235,4 +223,4 @@ int run_gemm_example(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigComputeV3>(argc, argv); }
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }
|
||||
|
||||
@@ -8,11 +8,10 @@
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -21,8 +20,7 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
bool Preshuffle = false>
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
@@ -33,17 +31,17 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 256;
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 1;
|
||||
constexpr ck_tile::index_t N_Warp = 4;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
@@ -52,8 +50,13 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmAQuantTraits<kPadM, kPadN, kPadK, Preshuffle, ALayout, BLayout, CLayout>;
|
||||
using CodegenGemmTraits = ck_tile::TileGemmAQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
@@ -82,6 +85,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
transposed_warp_gemm,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
@@ -96,7 +100,6 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
@@ -111,8 +114,8 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
@@ -136,7 +139,7 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
@@ -187,13 +190,14 @@ int run_gemm_example(int argc, char* argv[])
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>{});
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>{});
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
@@ -201,7 +205,7 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
@@ -210,29 +214,18 @@ int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigPreshufle_AQ>(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_gemm_example<GemmConfigPreshuffleQuant>(argc, argv);
|
||||
}
|
||||
|
||||
228
example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp
Normal file
228
example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp
Normal file
@@ -0,0 +1,228 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ComputeDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmBQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
GemmConfig::PreshuffleQuant,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
constexpr bool transposed_warp_gemm = false;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using CodegenGemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel =
|
||||
ck_tile::BQuantGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
;
|
||||
}
|
||||
|
||||
#include "run_gemm_bquant_example.inc"
|
||||
|
||||
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::BDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
|
||||
argc, argv, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for B.");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8i4")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8i4")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigDecode>(argc, argv); }
|
||||
@@ -11,11 +11,9 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE 5
|
||||
#define CK_TILE_PIPELINE_PREFILL 1
|
||||
#define CK_TILE_PIPELINE_DECODE 2
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -83,200 +81,37 @@ struct GemmConfigBase
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
struct GemmConfigDecode : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 2;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -288,18 +123,15 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufle_AQ : public GemmConfigBase
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
@@ -314,9 +146,11 @@ struct GemmConfigPreshufle_AQ : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
@@ -332,176 +166,6 @@ struct GemmQuantTypeConfig
|
||||
using CDataType = CDataType_;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::fp8_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::pk_int4_t, ck_tile::bf8_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::pk_int4_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, float, float>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using QDataType = float;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
@@ -559,55 +223,6 @@ struct DataTypeTraits<ck_tile::int8_t>
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline =
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
@@ -31,7 +32,8 @@ auto shuffle_aq(const ck_tile::HostTensor<T>& t, int block_aq_k)
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -40,8 +42,7 @@ template <typename ADataType,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
bool Preshuffle = false>
|
||||
uint32_t QuantGroupSize>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
@@ -73,7 +74,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.stride_C = stride_C;
|
||||
args.stride_AQ = stride_AQ;
|
||||
|
||||
float ave_time = gemm_calc_aquant<ADataType,
|
||||
float ave_time = gemm_calc_aquant<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -82,8 +84,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
Preshuffle>(
|
||||
QuantGroupSize>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
@@ -206,7 +207,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
if constexpr(GemmConfig::PreshuffleQuant)
|
||||
{
|
||||
ck_tile::HostTensor<AQDataType> aq_shuffle_host =
|
||||
shuffle_aq(aq_m_aqk, GemmConfig::K_Tile / QuantGroupSize);
|
||||
@@ -222,7 +223,8 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<ADataType,
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
@@ -231,22 +233,21 @@ int run_gemm_example_with_layouts(int argc,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize,
|
||||
GemmConfig::Preshuffle>(a_m_k_dev_buf,
|
||||
aq_m_aqk_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
stride_A,
|
||||
stride_AQ,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
QuantGroupSize>(a_m_k_dev_buf,
|
||||
aq_m_aqk_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
stride_A,
|
||||
stride_AQ,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
286
example/ck_tile/38_block_scale_gemm/run_gemm_bquant_example.inc
Normal file
286
example/ck_tile/38_block_scale_gemm/run_gemm_bquant_example.inc
Normal file
@@ -0,0 +1,286 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <bit>
|
||||
#include <random>
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_bq(const ck_tile::HostTensor<T>& t, int block_bq_k)
|
||||
{
|
||||
if(t.get_lengths().size() != 2)
|
||||
{
|
||||
throw std::runtime_error("Host tensor is not rank 2 tensor.");
|
||||
}
|
||||
int n_ = t.get_lengths()[0];
|
||||
int bqk_ = t.get_lengths()[1];
|
||||
if(bqk_ % block_bq_k != 0)
|
||||
{
|
||||
throw std::runtime_error("shuffle_aq needs a bqk of multiple times of block_bq_k.");
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_, bqk_ / block_bq_k, block_bq_k});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {1, 0, 2});
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& bq_bqk_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t BQK,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_BQ,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::BQuantGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.bq_ptr = bq_bqk_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.QK = BQK;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.stride_BQ = stride_BQ;
|
||||
|
||||
float ave_time = gemm_calc_bquant<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ADataType, // computeDatatype
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K +
|
||||
sizeof(BQDataType) * BQK * N + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideBQ =" << stride_BQ
|
||||
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
|
||||
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
|
||||
<< " A_Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B_Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " BQ_Type = " << DataTypeTraits<BQDataType>::name
|
||||
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
|
||||
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BQLayout bq_layout = BQLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using BQDataType = typename TypeConfig::QDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using CDataType = typename TypeConfig::CDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
if(K % QuantGroupSize != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be aligned with QuantGroupSize");
|
||||
}
|
||||
|
||||
ck_tile::index_t BQK = K / QuantGroupSize;
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_BQ = arg_parser.get_int("stride_q");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_BQ = ck_tile::get_default_stride(BQK, N, stride_BQ, is_row_major(bq_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<BQDataType> bq_bqk_n(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(bq_bqk_n);
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
std::cout << "Monotonic initialization is not supported." << std::endl;
|
||||
return 0;
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x38)}(a_m_k);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x22)}(b_k_n);
|
||||
ck_tile::FillConstant<BQDataType>{static_cast<BQDataType>(0.5f)}(bq_bqk_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
bq_bqk_n.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
QuantGroupSize>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
bq_bqk_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
BQK,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_BQ,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
BQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -12,7 +12,7 @@ This experimental kernel is intended for novice CK developers. It introduces the
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture
|
||||
# (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Make the copy kernel executable
|
||||
make tile_example_copy -j
|
||||
```
|
||||
|
||||
@@ -77,10 +77,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// we intentionally do not use pipeline for this example and let the kernel be composite of
|
||||
// Problem and Policy
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = Shape::BlockSize;
|
||||
auto blockSize = Kernel::BlockSize();
|
||||
|
||||
// Print configuration information
|
||||
std::cout << "block size (number of threads per block) " << kBlockSize << std::endl;
|
||||
std::cout << "block size (number of threads per block) " << blockSize << std::endl;
|
||||
std::cout << "wave size (number of threads per wave) " << ck_tile::get_warp_size() << std::endl;
|
||||
std::cout << "block waves (number of waves per block) " << BlockWaves::at(ck_tile::number<0>{})
|
||||
<< " " << BlockWaves::at(ck_tile::number<1>{}) << std::endl;
|
||||
@@ -99,16 +99,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< ")" << std::endl;
|
||||
|
||||
// Launch kernel
|
||||
float ave_time = launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, warmup, repeat, 1},
|
||||
ck_tile::make_kernel<kBlockSize, 1>(Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
m,
|
||||
n));
|
||||
float ave_time =
|
||||
launch_kernel(ck_tile::stream_config{nullptr, true, warmup, repeat, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{},
|
||||
kGridSize,
|
||||
blockSize,
|
||||
0,
|
||||
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
|
||||
m,
|
||||
n));
|
||||
|
||||
// Calculate and print performance metrics
|
||||
std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m * n;
|
||||
|
||||
@@ -27,8 +27,9 @@ struct TileCopyShape
|
||||
static constexpr index_t ThreadTile_N = ThreadTile::at(number<1>{});
|
||||
|
||||
// Wave tile dimensions
|
||||
static constexpr index_t Wave_Tile_M = WaveTile::at(number<0>{});
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t Wave_Tile_N = WaveTile::at(number<1>{});
|
||||
static constexpr index_t Wave_Tile_M = ThreadTile_M * ThreadTile_N * WaveSize / Wave_Tile_N;
|
||||
|
||||
// Block tile dimensions
|
||||
static constexpr index_t Block_Tile_M = BlockTile::at(number<0>{});
|
||||
@@ -45,7 +46,6 @@ struct TileCopyShape
|
||||
Block_Tile_N / (Waves_Per_Block_N * Wave_Tile_N);
|
||||
|
||||
// Hardware configuration
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
static constexpr index_t BlockSize = Waves_Per_Block_M * Waves_Per_Block_N * WaveSize;
|
||||
|
||||
// Configuration validation
|
||||
@@ -60,8 +60,10 @@ struct TileCopyShape
|
||||
"Invalid wave configuration for N dimension");
|
||||
|
||||
// Ensure wave tile dimensions align with wave size
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(Wave_Tile_M / ThreadTile_M * Wave_Tile_N / ThreadTile_N == WaveSize,
|
||||
"(Wave_Tile_M/ThreadTile_M) * (Wave_Tile_N/ThreadTile_N) != WaveSize");
|
||||
#endif
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -200,6 +202,19 @@ struct ElementWiseTileCopyKernel
|
||||
using XDataType = typename Problem::XDataType;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
{
|
||||
if(ck_tile::is_wave32())
|
||||
{
|
||||
return kBlockSize / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return kBlockSize;
|
||||
}
|
||||
}
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
@@ -50,10 +50,11 @@
|
||||
#endif
|
||||
|
||||
// define general macros for various architectures
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || \
|
||||
defined(__gfx9_4_generic__)
|
||||
#define __gfx9__
|
||||
#endif
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx9_4_generic__)
|
||||
#define __gfx94__
|
||||
#endif
|
||||
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -52,10 +52,27 @@ inline std::string get_device_name()
|
||||
}
|
||||
}
|
||||
|
||||
inline bool is_gfx12_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
|
||||
}
|
||||
|
||||
inline bool is_gfx11_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" ||
|
||||
ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" ||
|
||||
ck::get_device_name() == "gfx1152";
|
||||
}
|
||||
|
||||
inline bool is_xdl_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
|
||||
ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"
|
||||
#if defined(CK_ENABLE_DYNAMIC_WARP_SIZE)
|
||||
|| is_gfx12_supported() || is_gfx11_supported()
|
||||
#endif
|
||||
;
|
||||
}
|
||||
|
||||
inline bool is_lds_direct_load_supported()
|
||||
@@ -67,7 +84,8 @@ inline bool is_lds_direct_load_supported()
|
||||
|
||||
inline bool is_bf16_atomic_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
|
||||
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" ||
|
||||
is_gfx12_supported();
|
||||
}
|
||||
|
||||
inline bool is_gfx101_supported()
|
||||
@@ -83,18 +101,5 @@ inline bool is_gfx103_supported()
|
||||
ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036";
|
||||
}
|
||||
|
||||
inline bool is_gfx11_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" ||
|
||||
ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" ||
|
||||
ck::get_device_name() == "gfx1152";
|
||||
}
|
||||
|
||||
inline bool is_gfx12_supported()
|
||||
{
|
||||
return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201";
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
50
include/ck/library/utility/validation_common.hpp
Normal file
50
include/ck/library/utility/validation_common.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
|
||||
template <typename Layout>
|
||||
inline void
|
||||
validate_gemm_stride(int M, int N, int stride, const std::string& stride_name = "Stride")
|
||||
{
|
||||
if(ck::is_same_v<Layout, ck::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if(stride < M)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Error: For ColumnMajor layout, " + stride_name + " (" + std::to_string(stride) +
|
||||
") must be greater than or equal to dim (" + std::to_string(M) + ")");
|
||||
}
|
||||
}
|
||||
else // RowMajor
|
||||
{
|
||||
if(stride < N)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Error: For RowMajor layout, " + stride_name + " (" + std::to_string(stride) +
|
||||
") must be greater than or equal to dim (" + std::to_string(N) + ")");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience functions for common GEMM patterns
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
inline void validate_gemm_strides_abc(int M, int N, int K, int StrideA, int StrideB, int StrideC)
|
||||
{
|
||||
validate_gemm_stride<ALayout>(M, K, StrideA, "StrideA");
|
||||
validate_gemm_stride<BLayout>(K, N, StrideB, "StrideB");
|
||||
validate_gemm_stride<CLayout>(M, N, StrideC, "StrideC");
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
@@ -41,7 +41,9 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
// Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
|
||||
static constexpr index_t WaveSize = 64;
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
|
||||
|
||||
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
|
||||
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
|
||||
@@ -74,9 +76,6 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
|
||||
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
|
||||
|
||||
using HotLoopInstList =
|
||||
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
|
||||
MPerBlock,
|
||||
@@ -219,6 +218,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
@@ -227,6 +227,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
|
||||
static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
|
||||
"wrong!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -139,9 +139,10 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
@@ -625,13 +626,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
|
||||
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
|
||||
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -141,9 +141,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::WaveSize;
|
||||
|
||||
static constexpr index_t WgpPerCU =
|
||||
(4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
|
||||
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
|
||||
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
|
||||
32768 / WgpPerCU,
|
||||
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user