diff --git a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py new file mode 100644 index 0000000000..557afe2d84 --- /dev/null +++ b/.github/scripts/therock_configure_ci.py @@ -0,0 +1,112 @@ +import fnmatch +import json +import os +from pathlib import Path +import subprocess +import sys +from typing import Iterable, Optional, Mapping + +def gha_set_output(vars: Mapping[str, str | Path]): + """Sets values in a step's output parameters. + + This appends to the file located at the $GITHUB_OUTPUT environment variable. + + See + * https://docs.github.com/en/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter + * https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs + """ + print(f"Setting github output:\n{vars}") + + step_output_file = os.getenv("GITHUB_OUTPUT") + if not step_output_file: + print(" Warning: GITHUB_OUTPUT env var not set, can't set github outputs") + return + + with open(step_output_file, "a") as f: + f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items()) + +def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]: + """Returns the paths of modified files relative to the base reference.""" + try: + return subprocess.run( + ["git", "diff", "--name-only", base_ref], + stdout=subprocess.PIPE, + check=True, + text=True, + timeout=60, + ).stdout.splitlines() + except TimeoutError: + print( + "Computing modified files timed out. Not using PR diff to determine" + " jobs to run.", + file=sys.stderr, + ) + return None + +# Paths matching any of these patterns are considered to have no influence over +# build or test workflows so any related jobs can be skipped if all paths +# modified by a commit/PR match a pattern in this list. +SKIPPABLE_PATH_PATTERNS = [ + "docs/*", + "*.gitignore", + "*.md", + "*.pre-commit-config.*", + "*LICENSE", + 'Jenkinsfile', + '.github/ISSUE_TEMPLATE/*', + '.github/CODEOWNERS', + '.github/*.md', + '.github/dependabot.yml', +] + +def is_path_skippable(path: str) -> bool: + """Determines if a given relative path to a file matches any skippable patterns.""" + return any(fnmatch.fnmatch(path, pattern) for pattern in SKIPPABLE_PATH_PATTERNS) + +def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool: + """Returns true if at least one path is not in the skippable set.""" + if paths is None: + return False + return any(not is_path_skippable(p) for p in paths) + +def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool: + """Returns true if CI workflows should run given a list of modified paths.""" + + if paths is None: + print("No files were modified, skipping TheRock CI jobs") + return False + + paths_set = set(paths) + github_workflows_paths = set( + [p for p in paths if p.startswith(".github/workflows")] + ) + other_paths = paths_set - github_workflows_paths + + contains_other_non_skippable_files = check_for_non_skippable_path(other_paths) + + print("should_ci_run_given_modified_paths findings:") + print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}") + + if contains_other_non_skippable_files: + print("Enabling TheRock CI jobs since a non-skippable path was modified") + return True + else: + print( + "Only unrelated and/or skippable paths were modified, skipping TheRock CI jobs" + ) + return False + +def main(args): + base_ref = args.get("base_ref") + modified_paths = get_modified_paths(base_ref) + print("modified_paths (max 200):", modified_paths[:200]) + enable_jobs = should_ci_run_given_modified_paths(modified_paths) + output = { + 'enable_therock_ci': json.dumps(enable_jobs) + } + gha_set_output(output) + +if __name__ == "__main__": + args = {} + args["base_ref"] = os.environ.get("BASE_REF", "HEAD^1") + main(args) diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml new file mode 100644 index 0000000000..7db124d2a1 --- /dev/null +++ b/.github/workflows/therock-ci-linux.yml @@ -0,0 +1,130 @@ +name: TheRock CI Linux + +on: + workflow_call: + inputs: + cmake_options: + type: string + amdgpu_families: + type: string + test_runs_on: + type: string + +permissions: + contents: read + +jobs: + therock-build-linux: + name: Build Linux Packages + runs-on: azure-linux-scale-rocm + permissions: + id-token: write + container: + image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292 + options: -v /runner/config:/home/awsconfig/ + env: + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + TEATIME_FORCE_INTERACTIVE: 0 + AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini + steps: + - name: Checkout composable_kernel repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Checkout TheRock repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/TheRock" + ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57 + path: "TheRock" + + - name: Runner Health Settings + run: | + df -h + cmake --version + echo "Installed Python versions:" + ls -d /opt/python + echo "python: $(which python), python3: $(which python3)" + echo "Git version: $(git --version)" + git config --global --add safe.directory $PWD + git config fetch.parallel 10 + + - name: Fetch sources + run: | + ./TheRock/build_tools/fetch_sources.py --jobs 12 + + - name: Install python deps + run: | + pip install -r TheRock/requirements.txt + pip freeze + + - name: Configure Projects + env: + amdgpu_families: ${{ env.AMDGPU_FAMILIES }} + package_version: ADHOCBUILD + extra_cmake_options: ${{ inputs.cmake_options }} + BUILD_DIR: build + run: | + python3 TheRock/build_tools/github_actions/build_configure.py + + - name: Build TheRock + run: cmake --build TheRock/build + + - name: Build therock-archives + run: cmake --build TheRock/build --target therock-archives + + - name: Report + if: ${{ !cancelled() }} + run: | + echo "Full SDK du:" + echo "------------" + du -h -d 1 TheRock/build/dist/rocm + echo "Artifact Archives:" + echo "------------------" + ls -lh TheRock/build/artifacts/*.tar.xz + echo "Artifacts:" + echo "----------" + du -h -d 1 TheRock/build/artifacts + + - name: Configure AWS Credentials for non-forked repos + if: ${{ always() && !github.event.pull_request.head.repo.fork }} + uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1 + with: + aws-region: us-east-2 + role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external + + - name: Create Logs index Files and upload logs + if: always() + run: | + python3 TheRock/build_tools/github_actions/create_log_index.py \ + --build-dir=TheRock/build \ + --amdgpu-family=${{ env.AMDGPU_FAMILIES }} + + python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \ + --build-dir=TheRock/build \ + --run-id ${{ github.run_id }} \ + --amdgpu-family ${{ env.AMDGPU_FAMILIES }} + + - name: Upload artifacts + run: | + python TheRock/build_tools/github_actions/upload_build_artifacts.py \ + --run-id ${{ github.run_id }} \ + --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ + --build-dir TheRock/build + + - name: Add Links to Job Summary + if: always() + run: | + python TheRock/build_tools/github_actions/upload_build_summary.py \ + --run-id ${{ github.run_id }} \ + --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ + --build-dir TheRock/build + + therock-test-linux: + name: "Test" + needs: [therock-build-linux] + uses: ./.github/workflows/therock-test-packages.yml + with: + project_to_test: "miopen" + amdgpu_families: ${{ inputs.amdgpu_families }} + test_runs_on: ${{ inputs.test_runs_on }} + platform: "linux" diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml new file mode 100644 index 0000000000..3232652b6b --- /dev/null +++ b/.github/workflows/therock-ci.yml @@ -0,0 +1,81 @@ +name: TheRock CI for composable_kernel + +on: + push: + branches: + - develop + workflow_dispatch: + pull_request: + types: + - opened + - synchronize + branches: + - mainline + - release/* + - release-staging/* + - develop + +permissions: + contents: read + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + setup: + runs-on: ubuntu-24.04 + env: + # The commit being checked out is the merge commit for a PR. Its first + # parent will be the tip of the base branch. + BASE_REF: HEAD^ + outputs: + enable_therock_ci: ${{ steps.configure.outputs.enable_therock_ci }} + steps: + - name: "Checking out repository" + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + # We need the parent commit to do a diff + fetch-depth: 2 + + - name: "Configuring CI options" + id: configure + run: python .github/scripts/therock_configure_ci.py + + therock-ci-linux: + name: TheRock CI Linux + needs: setup + if: ${{ needs.setup.outputs.enable_therock_ci == 'true' }} + permissions: + contents: read + id-token: write + uses: ./.github/workflows/therock-ci-linux.yml + secrets: inherit + with: + cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../" + amdgpu_families: "gfx94X-dcgpu" + test_runs_on: "linux-mi325-1gpu-ossci-rocm" + + therock_ci_summary: + name: TheRock CI Summary + if: always() + needs: + - setup + - therock-ci-linux + runs-on: ubuntu-24.04 + steps: + - name: Output failed jobs + run: | + echo '${{ toJson(needs) }}' + FAILED_JOBS="$(echo '${{ toJson(needs) }}' \ + | jq --raw-output \ + 'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \ + )" + if [[ "${FAILED_JOBS}" != "" ]]; then + echo "The following jobs failed: ${FAILED_JOBS}" + exit 1 + fi diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml new file mode 100644 index 0000000000..37ddd399ad --- /dev/null +++ b/.github/workflows/therock-test-packages.yml @@ -0,0 +1,77 @@ +name: TheRock Test Packages + +on: + workflow_call: + inputs: + project_to_test: + type: string + amdgpu_families: + type: string + test_runs_on: + type: string + platform: + type: string + +permissions: + contents: read + +jobs: + configure_test_matrix: + name: "Configure test matrix" + runs-on: ubuntu-24.04 + if: ${{ inputs.test_runs_on != '' }} + outputs: + components: ${{ steps.configure.outputs.components }} + steps: + - name: "Checking out repository" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/TheRock" + + - name: "Configuring CI options" + env: + PLATFORM: ${{ inputs.platform }} + project_to_test: ${{ inputs.project_to_test }} + id: configure + run: python ./build_tools/github_actions/fetch_test_configurations.py + + test_components: + name: 'Test ${{ matrix.components.job_name }}' + runs-on: ${{ inputs.test_runs_on }} + needs: configure_test_matrix + # skip tests if no test matrix to run + if: ${{ needs.configure_test_matrix.outputs.components != '[]' }} + strategy: + fail-fast: false + matrix: + components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }} + defaults: + run: + shell: bash + env: + VENV_DIR: ${{ github.workspace }}/.venv + ARTIFACT_RUN_ID: "${{ github.run_id }}" + OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build + THEROCK_BIN_DIR: "./build/bin" + steps: + - name: Checkout Repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/TheRock" + + - name: Run setup test environment workflow + uses: './.github/actions/setup_test_environment' + with: + ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} + VENV_DIR: ${{ env.VENV_DIR }} + FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }} + PLATFORM: ${{ inputs.platform }} + IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} + + - name: Test + timeout-minutes: ${{ matrix.components.timeout_minutes }} + run: | + if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . ${VENV_DIR}/Scripts/activate ; fi + ${{ matrix.components.test_script }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c942a776d..1246248eac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added rotating buffer feature for CK_Tile GEMM. * Added int8 support for CK_TILE GEMM. * Added support for elementwise kernel. +* Added benchmarking support for tile engine GEMM Multi D. ### Optimized @@ -47,6 +48,7 @@ None * Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced. * Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced. * Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced. +* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594) ### Known issues diff --git a/CMakeLists.txt b/CMakeLists.txt index 19c036e1a5..35ebba8085 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -336,6 +336,11 @@ if(ENABLE_ASM_DUMP) message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() +if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12")) + add_compile_options(-mno-wavefrontsize64) + message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/Jenkinsfile b/Jenkinsfile index 590ee92e90..b3b63098c2 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -401,7 +401,8 @@ def cmake_build(Map conf=[:]){ sh 'ninja -j64 package' archiveArtifacts artifacts: 'composablekernel-dev*.deb' sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb' - stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages" + sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' + stash includes: "composablekernel-**.deb", name: "packages" } } else{ @@ -460,7 +461,9 @@ def buildHipClangJob(Map conf=[:]){ } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + // the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with + // newer clang22 compilers and running with older hip runtima libraries + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 " } def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') @@ -518,7 +521,9 @@ def Build_CK(Map conf=[:]){ } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + // the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with + // newer clang22 compilers and running with older hip runtima libraries + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 " } if(params.BUILD_LEGACY_OS){ dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' " @@ -567,19 +572,6 @@ def Build_CK(Map conf=[:]){ python3 -m pytest python/test/test_gen_instances.py """ } - dir("build"){ - if (params.RUN_FULL_QA && arch == 2 ){ - // build deb packages - echo "Build packages" - sh 'ninja package' - archiveArtifacts artifacts: 'composablekernel*.deb' - sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' - sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb' - sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb' - sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb' - stash includes: "composablekernel-**.deb", name: "packages" - } - } // run performance tests, stash the logs, results will be processed on the master node dir("script"){ if (params.RUN_PERFORMANCE_TESTS){ @@ -734,7 +726,7 @@ def process_results(Map conf=[:]){ echo "could not locate the FMHA performance logs: ${err.getMessage()}." } } - if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){ + if (params.BUILD_INSTANCES_ONLY){ // unstash deb packages unstash "packages" sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" @@ -888,6 +880,10 @@ pipeline { name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS", defaultValue: false, description: "Run the grouped conv large cases tests (default: OFF)") + booleanParam( + name: "RUN_CONV_COMPREHENSIVE_DATASET", + defaultValue: false, + description: "Run comprehensive convolution dataset tests before important changes (default: OFF)") booleanParam( name: "RUN_CODEGEN_TESTS", defaultValue: true, @@ -1086,6 +1082,33 @@ pipeline { } } } + stage("Run Comprehensive Convolution Dataset Tests") + { + parallel + { + stage("Run Comprehensive Dataset Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CONV_COMPREHENSIVE_DATASET.toBoolean() } + } + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ cd test_data && \ + ./generate_test_dataset.sh && \ + cd ../script && \ + ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 test_grouped_convnd_fwd_dataset_xdl && \ + ./bin/test_grouped_convnd_fwd_dataset_xdl""" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } stage("Run Codegen Tests") { parallel @@ -1172,6 +1195,8 @@ pipeline { -D GPU_TARGETS="gfx90a" \ -D GEMM_DATATYPE="fp8;fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_MULTI_D_DATATYPE="fp16" \ + -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_fp8_rcr && \ ./bin/benchmark_gemm_fp8_rcr && \ @@ -1188,7 +1213,15 @@ pipeline { ninja -j64 benchmark_gemm_fp8_rrr && \ ./bin/benchmark_gemm_fp8_rrr && \ ninja -j64 benchmark_gemm_fp16_rrr && \ - ./bin/benchmark_gemm_fp16_rrr """ + ./bin/benchmark_gemm_fp16_rrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ + ./bin/benchmark_gemm_multi_d_fp16_ccrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \ + ./bin/benchmark_gemm_multi_d_fp16_crrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rcrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1210,6 +1243,8 @@ pipeline { -D GPU_TARGETS="gfx942" \ -D GEMM_DATATYPE="fp8;fp16" \ -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_MULTI_D_DATATYPE="fp16" \ + -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_fp8_rcr && \ ./bin/benchmark_gemm_fp8_rcr && \ @@ -1226,7 +1261,15 @@ pipeline { ninja -j64 benchmark_gemm_fp8_rrr && \ ./bin/benchmark_gemm_fp8_rrr && \ ninja -j64 benchmark_gemm_fp16_rrr && \ - ./bin/benchmark_gemm_fp16_rrr """ + ./bin/benchmark_gemm_fp16_rrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ + ./bin/benchmark_gemm_multi_d_fp16_ccrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \ + ./bin/benchmark_gemm_multi_d_fp16_crrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rcrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1385,7 +1428,7 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0") } cleanWs() } @@ -1419,7 +1462,7 @@ pipeline { } agent{ label rocmnode("gfx1101") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx11-generic" \ @@ -1440,7 +1483,7 @@ pipeline { } agent{ label rocmnode("gfx1201") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DUSE_OPT_GFX12=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx12-generic" \ @@ -1462,7 +1505,7 @@ pipeline { stage("Process results"){ when { beforeAgent true - expression { params.RUN_PERFORMANCE_TESTS.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { (params.RUN_PERFORMANCE_TESTS.toBoolean() || params.BUILD_INSTANCES_ONLY.toBoolean()) && !params.BUILD_LEGACY_OS.toBoolean() } } agent { label 'mici' } steps{ diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 6c5d9f9fba..3e018aad1e 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -1,7 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/library/utility/validation_common.hpp" template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) @@ -53,6 +54,17 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + try + { + ck::utils::validate_gemm_strides_abc( + M, N, K, StrideA, StrideB, StrideC); + } + catch(const std::runtime_error& e) + { + std::cerr << "Error: " << e.what() << std::endl; + return false; + } + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 4adb6f896b..3d8cf32221 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 18731e810e..03c531c1ad 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 87812369bd..5167097b6d 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index c3e6ef7d5d..abf7ef3905 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp index 93034a8b70..2582ea8a11 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index e7c1d6f0be..57e2feb084 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 72109a660b..f72d7afa02 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -7,7 +7,7 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ make tile_example_fmha_fwd -j ``` This will result in an executable `build/bin/tile_example_fmha_fwd` diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 6fca800c90..42a9d5148a 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -115,6 +115,7 @@ PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", } PIPELINE_ENUM_MAP = { @@ -123,6 +124,7 @@ PIPELINE_ENUM_MAP = { "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index ffb6d579ed..0d8f366d8a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -84,6 +84,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< {F_mode}, fmha_variant_{F_idx}, fmha_mask_{F_idx}, + false, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -98,7 +99,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; #include @@ -109,9 +110,9 @@ float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_b if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ @@ -177,7 +178,7 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; return fmha_batch_prefill_(s, a); }} """ @@ -507,8 +508,8 @@ class KernelComponentFactory: for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: assert False return pipelines diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index bb3a0587e7..0391191fb2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -136,10 +136,10 @@ float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config& s if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> @@ -148,9 +148,9 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_co {{ using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} @@ -425,10 +425,10 @@ float fmha_bwd_dot_do_o_(const ck_tile::stream_config& s if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> @@ -436,9 +436,9 @@ void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_co {{ using k_ = fmha_bwd_dot_do_o_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} @@ -530,10 +530,10 @@ float fmha_bwd_convert_dq_(const ck_tile::stream_confi if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} template <> @@ -542,9 +542,9 @@ void fmha_bwd_convert_dq_oneshot_(const ck_tile::strea {{ using k_ = fmha_bwd_convert_dq_kernel_{F_idx}; auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)( ck_tile::stream_config{{s.stream_id_}}); }} diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 95202a5f72..78668729f4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -12,6 +12,7 @@ from typing import List, Optional, Tuple from codegen.cmake_config import * from codegen.cpp_symbol_map import * +from codegen.utils import update_file DTYPE_BITS = { @@ -83,6 +84,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< {F_mode}, fmha_variant_{F_idx}, fmha_mask_{F_idx}, + {F_trload}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -97,7 +99,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; #include @@ -108,9 +110,9 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ @@ -161,12 +163,19 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; }} """ +FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} @@ -177,8 +186,8 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < """ FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>; + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_fwd_(s, a); }} """ @@ -221,6 +230,7 @@ class FmhaFwdApiTrait: dpad : str dvpad : str skip : str + tr_load : str constraint : CppConstraint @property @@ -231,13 +241,19 @@ class FmhaFwdApiTrait: @property def scheck(self) -> str: if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag in ['qr_async', 'qr_async_trload']: if self.spad == 't' : return 'true' # always support else : return 'true' elif self.pipeline_tag in ['qr', 'qs']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False + + @property + def seqtune(self) -> str: + if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true + else: + return f'a.seqlen_q <= {self.bm0}' @property def skcheck(self) -> str: @@ -248,6 +264,9 @@ class FmhaFwdApiTrait: elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag == 'qr_async_trload': + if self.skpad == 't' : return 'true' + else: return 'true' else: assert False @property @@ -256,7 +275,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -268,7 +287,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr', 'qs']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -290,6 +309,7 @@ class FmhaFwdPipeline: F_squant : str # F_mask : str # value from MASK_MAP F_skip : str # true/false + F_trload : str # true/false F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) @property @@ -331,6 +351,9 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' return n @@ -351,31 +374,39 @@ class FmhaFwdApiPool: @property def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][(hdim, hdim_v)] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_constraint=trait.constraint, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true" + } + + per_tr_load =str() + for tr_load in ["t", "f"]: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) + if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + per_tr_load += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) @dataclass class FmhaFwdTileSize: @@ -458,7 +489,8 @@ class FmhaFwdKernel: F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_trload = BOOL_MAP[self.F_pipeline.F_trload]) @property def name(self) -> str: @@ -494,6 +526,7 @@ class FmhaFwdKernel: dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) class KernelComponentFactory: @@ -503,11 +536,16 @@ class KernelComponentFactory: def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], @@ -534,37 +572,30 @@ class KernelComponentFactory: if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: - # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - # pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - # pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim - # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', 'f', mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', 'f', mask, 'f', 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -602,6 +633,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue @@ -668,10 +705,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl return (api_pool, gen) def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) + update_file(autogen_dir / kernel.filename, kernel.template) def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 2e5bc2bd3d..0ebeaddf9c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -60,9 +60,9 @@ float fmha_fwd_appendkv_(const ck_tile::stream_config& s, fmha_fw if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 0e4ac44d45..1dd8f0e3c6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -108,9 +108,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ using k_ = fmha_kernel; auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} }}; }} @@ -208,9 +208,9 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ using k_ = fmha_kernel; auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); + ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} }}; }} @@ -638,7 +638,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -657,7 +657,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d '64' : FmhaFwdSplitKVCombineTileSize(32, -1), '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - '160' : FmhaFwdSplitKVCombineTileSize(32, -1), + # '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index a98d1d4423..43a69cca6c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -109,9 +109,9 @@ float fmha_fwd_pagedkv_(const ck_tile::stream_config& s, fmha_fwd if(s.log_level_ > 0) std::cout << ", " << k_::GetName() << std::flush; auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids(a); - constexpr dim3 blocks = k_::BlockSize(); + const dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; - return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 9c2907778f..9f1e0f6948 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -809,20 +809,6 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::stream_config stream_config_v{ nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; - - printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, " - "bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n", - fmha_traits.hdim_q, - fmha_traits.hdim_v, - fmha_traits.data_type.c_str(), - fmha_traits.is_group_mode, - static_cast(fmha_traits.mask_type), - static_cast(fmha_traits.bias_type), - fmha_traits.has_dbias, - fmha_traits.has_dropout, - fmha_traits.is_store_randval, - fmha_traits.is_deterministic); - fflush(stdout); fmha_bwd(fmha_traits, fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index ee599c973b..777ae59db3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1144,7 +1144,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec - << " GB/s" << std::flush; + << " GB/s" << std::flush << std::endl; if(do_validation == 0) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index bd5e110214..8c712b0aa7 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/fmha.hpp" @@ -1028,6 +1029,7 @@ template struct fmha_fwd_traits_ { @@ -1052,6 +1054,7 @@ struct fmha_fwd_traits_ static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; + static constexpr bool kUseTrLoad = kUseTrLoad_; static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 599c595a75..88c16cceb6 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,14 +18,3 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done - -for perm in 0 1 ; do - -$EXE -prec=fp8 -squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 - -done \ No newline at end of file diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index b5e6778aa5..e7babd2744 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -9,6 +9,8 @@ # host name : $hostname # gpu architecture: e.g., gfx90a, or gfx942, etc. +set -euo pipefail + #get the command line arguments: export env_type=$1 echo 'Environment type: ' $env_type diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index 5ba3425e26..d123f842a2 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -1,5 +1,7 @@ -#!/bin/sh +#!/bin/bash # TODO: run this script from CK root or build directory +set -euo pipefail + EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)" KNAME=1 @@ -17,12 +19,12 @@ for dbias in 0 ; do for p_drop in 0.0 0.2 ; do for deterministic in 0 ; do -$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS done done diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index b867cd6c07..3913a0d5c2 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -1,5 +1,7 @@ #!/bin/bash # TODO: run this script from CK root or build directory +set -euo pipefail + EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" KNAME=1 @@ -42,7 +44,6 @@ run_fp16_bf16_tests() { for prec in "fp16" "bf16" ; do for mode in 1 0 ; do for perm in 0 1 ; do - for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do for bias in "n" "e" "a" ; do @@ -51,20 +52,19 @@ run_fp16_bf16_tests() { for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done - done ; } run_fp8_tests() { diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index 817f62dae7..da74e2e3c1 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -42,7 +42,7 @@ return hidden_states, per_token_scale ``` # in the root of ck_tile mkdir build && cd build -sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... make tile_example_layernorm2d_fwd -j ``` This will result in an executable `build/bin/tile_example_layernorm2d_fwd` diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index d77582630a..c4366f6662 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -235,7 +235,7 @@ float layernorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Layernorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); @@ -243,7 +243,7 @@ float layernorm2d_fwd_(const S& s, A a) std::cout << ", " << Kernel::GetName() << std::flush; return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{{}}, grids, blocks, 0, kargs)); }} """ diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index b1aede42c7..825cd6e522 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,6 +1,7 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) +add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) @@ -14,3 +15,4 @@ list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-rev target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index 20cc202176..6358b76fd9 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -7,18 +7,19 @@ This folder contains example for GEMM using ck_tile tile-programming implementat # in the root of ck_tile mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ +../script/cmake-ck-dev.sh ../ # The basic pipeline method on the gemm calculation make tile_example_gemm_basic -j # The memory bound pipeline on the gemm calculation make tile_example_gemm_universal -j +# The weight preshuffle pipeline on the gemm calculation +make tile_example_gemm_weight_preshuffle -j ``` This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal` ## example ``` args: - -b batch size (default:1) -m m dimension (default:1024) -n n dimension (default:2048) -k k dimension (default:64) @@ -29,9 +30,11 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) - -e Absolute error tolerance (default:1e-5) -prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16) -warmup number of iterations before benchmark the kernel (default:10) -repeat number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -split_k splitK value (default:1) + -init 0:random, 1:linear, 2:constant (default:1) + -persistent 0:non-persistent, 1:persistent (default:0) ``` diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 0d9c2d9957..8cdbe39e86 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,15 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include -#include -#include -#include -#include - -#include "ck_tile/host.hpp" #include "gemm_utils.hpp" template ; - using CodegenGemmTraits = - ck_tile::TileGemmTraits; + using CodegenGemmTraits = ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; @@ -76,7 +65,6 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, M_Warp, @@ -92,8 +80,8 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -112,27 +100,27 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) } float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); return ave_time; }; if(args.k_batch == 1) { - return Run(ck_tile::integral_constant{}); + return Run(MemoryOpSet{}); } else { - return Run(ck_tile::integral_constant{}); + return Run(MemoryOpAtomicAdd{}); } } #include "run_gemm_example.inc" template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +int run_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -142,12 +130,12 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -160,22 +148,22 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a if(a_layout == "R" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "R" && b_layout == "R") { return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Row{}); + arg_parser, Row{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { return run_gemm_example_with_layouts( - argc, argv, Col{}, Row{}, Row{}); + arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { return run_gemm_example_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + arg_parser, Col{}, Col{}, Row{}); } else { @@ -184,38 +172,34 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } } -int run_gemm_example(int argc, char* argv[]) +int run_gemm_example(ck_tile::ArgParser& arg_parser) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); if(data_type == "fp16") { - return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "bf16") { - return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "fp8") { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); } else if(data_type == "bf8") { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); } else if(data_type == "i8") { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); } else if(data_type == "pk_int4_t") { @@ -223,7 +207,7 @@ int run_gemm_example(int argc, char* argv[]) if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); } else { @@ -238,9 +222,13 @@ int run_gemm_example(int argc, char* argv[]) int main(int argc, char* argv[]) { + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + try { - return !run_gemm_example(argc, argv); + return !run_gemm_example(arg_parser); } catch(const std::runtime_error& e) { diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp new file mode 100644 index 0000000000..f42135a0b5 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -0,0 +1,1006 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" + +/** + * @brief Tile partitioner with output offset support. + * + * This partitioner extends the spatially local tile partitioner to support + * split-K reduction by providing workspace output offset calculation. Each K-split + * writes to a separate slice of the workspace: workspace[k_id * M * N]. + */ +template +struct GemmSplitKTilePartitioner + : public ck_tile::GemmSpatiallyLocalTilePartitioner +{ + using Base = ck_tile::GemmSpatiallyLocalTilePartitioner; + + // Inherit constructors and methods + using Base::Base; + using Base::GetLoopNum; + + /** + * @brief Calculate output pointer offset for split-K reduction. + * + * @param kargs Kernel arguments. + * @param k_id Current K-split ID (from blockIdx.z or calculated k_batch). + * @return ck_tile::index_t The offset for this K-split. + */ + template + CK_TILE_HOST_DEVICE static ck_tile::index_t GetOutputOffset(const KernelArgs& kargs, + ck_tile::index_t k_id) noexcept + { + // Each K-split gets its own M*N workspace slice + return (kargs.k_batch > 1) ? (k_id * kargs.M * kargs.N) : 0; + } +}; + +/** + * @brief Extended GEMM host arguments for two-stage split-K implementation + * + * This structure supports the two-stage split-K approach where: + * 1. Stage 1: GEMM writes partial results to workspace memory + * 2. Stage 2: Reduction kernel sums workspace results to final output + * + * The base class e_ptr points to workspace, while final_output_ptr points to the actual output + */ +struct GemmSplitKHostArgs : public ck_tile::GemmHostArgs +{ + using BaseArgs = ck_tile::GemmHostArgs; + + CK_TILE_HOST GemmSplitKHostArgs() = default; + CK_TILE_HOST GemmSplitKHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* workspace_ptr_, // Workspace for partial results + void* e_ptr_, // Final output destination + ck_tile::index_t k_batch_, + ck_tile::index_t M_, + ck_tile::index_t N_, + ck_tile::index_t K_, + ck_tile::index_t stride_A_, + ck_tile::index_t stride_B_, + ck_tile::index_t workspace_stride_, + ck_tile::index_t stride_E_) + : BaseArgs(a_ptr_, + b_ptr_, + workspace_ptr_, // Base e_ptr = workspace_ptr + k_batch_, + M_, + N_, + K_, + stride_A_, + stride_B_, + workspace_stride_), + final_output_ptr(e_ptr_), + final_stride_E(stride_E_) + { + } + + void* final_output_ptr; // Pointer to final output tensor + ck_tile::index_t final_stride_E; // Stride for final output tensor +}; + +/** + * @brief Stage 1: GEMM kernel that writes partial split-K results to workspace + * + * This function performs the matrix multiplication with split-K, where each + * K-split writes its partial result to a separate section of the workspace. + * + * Workspace layout: [k_batch, M, N] where each [M, N] slice contains + * partial results for one K-split. + * + * @param args Extended arguments containing workspace and final output pointers + * @param s Stream configuration for kernel execution + * @return Execution time in milliseconds + */ +template +float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = GemmSplitKTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + // Create base GEMM arguments pointing to workspace instead of final output + // The workspace will store partial results from each K-split + ck_tile::GemmHostArgs base_args(args.a_ptr, + args.b_ptr, + args.e_ptr, + args.k_batch, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + args.stride_E); + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(base_args); + + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + // For workspace mode, always use SET operation since each K-split writes to separate memory + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} + +/** + * @brief Stage 2: Reduction kernel that sums partial split-K results to final output + * + * This function reduces the partial results stored in workspace memory by stage 1. + * It sums across the k_batch dimension to produce the final GEMM result. + * + * Workspace layout: [k_batch, M, N] -> Final output: [M, N] + * + * @tparam CDataType Output data type + * @tparam ComputeDataType Computation precision for reduction + * @tparam ELayout Memory layout of output tensor + * @param args Extended arguments containing workspace and output information + * @param s Stream configuration for kernel execution + * @return Execution time in milliseconds + */ +template +float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s) +{ + const ck_tile::index_t reduce_dim_size = args.k_batch; // Number of partial results to reduce + // Calculate output size based on the final output tensor dimensions + const ck_tile::index_t output_size = args.M * args.N; + + // Workspace layout: [k_batch, M, N] where each [M, N] slice has the same layout as final output + // The workspace strides need to account for the layout of the final output tensor + auto workspace_shape = ck_tile::make_tuple(args.k_batch, args.M, args.N); + auto workspace_strides = + ck_tile::make_tuple(args.M * args.N, // k_batch stride: jump to next K split + args.final_stride_E, // stride same as final output stride + 1); + + // Define kept and reduced dimensions + constexpr auto kept_dim = ck_tile::sequence<1, 2>{}; // Keep M, N dimensions + constexpr auto reduce_dims = ck_tile::sequence<0>{}; // Reduce k_batch dimension + + using ReduceOp = ck_tile::ReduceOp::Add; + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using ThreadTile = ck_tile::sequence<8, 8>; + + constexpr ck_tile::index_t kBlockSize = 256; + constexpr ck_tile::index_t kBlockPerCu = 1; + + ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) / + BlockTile::at(ck_tile::number<0>{}); + + using Shape = ck_tile::Reduce2dShape; + using Problem = + ck_tile::Reduce2dProblem; + using Kernel = ck_tile::Reduce; + + if(!Kernel::IsSupportedArgument(reduce_dim_size, workspace_strides)) + { + throw std::runtime_error("Wrong! Reduction arguments not supported!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Stage 2 - Launching Reduction kernel" << '\n' + << "workspace shape: [" << args.k_batch << ", " << args.M << ", " << args.N << "]" + << '\n' + << "output shape: [" << args.M << ", " << args.N << "]" << '\n' + << "grid size: " << kGridSize << std::endl; + } + + float ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, // LDS size + static_cast(args.e_ptr), // workspace input + static_cast(args.final_output_ptr), // final output + workspace_shape, + workspace_strides, + kept_dim, + reduce_dims)); + + return ave_time; +} + +/** + * @brief Orchestrator for two-stage split-K GEMM implementation + * + * This function coordinates the two-stage approach: + * 1. Stage 1: Execute GEMM with each K-split writing to workspace + * 2. Stage 2: Reduce workspace results to final output (if k_batch > 1) + * + * @param args Extended arguments for two-stage execution + * @param s Stream configuration + * @return Total execution time (GEMM + Reduction) + */ +template +float gemm_splitk_two_stage(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s) +{ + float gemm_time = 0.0f; + float reduce_time = 0.0f; + + if(s.log_level_ > 0) + { + std::cout << "Starting Two-Stage GEMM+SplitK with k_batch=" << args.k_batch << std::endl; + std::cout << "Workspace size: " << args.k_batch << " x " << args.M << " x " << args.N + << " = " << args.k_batch * args.M * args.N * sizeof(CDataType) << " bytes" + << std::endl; + } + + // Stage 1: GEMM to workspace + gemm_time = gemm_stage1(args, s); + + // Synchronize before stage 2 + auto sync_result = hipStreamSynchronize(s.stream_id_); + if(sync_result != hipSuccess) + { + throw std::runtime_error("Stream synchronization failed"); + } + + // Stage 2: Reduction from workspace to final output (if needed) + if(args.k_batch > 1) + { + // Use appropriate precision for reduction computations + using ComputeDataType = std::conditional_t< + std::is_same_v, + float, + std::conditional_t, float, CDataType>>; + reduce_time = reduce_stage2(args, s); + } + else + { + // Single K-split: simple copy from workspace to final output + auto copy_result = hipMemcpyAsync(args.final_output_ptr, + args.e_ptr, + args.M * args.N * sizeof(CDataType), + hipMemcpyDeviceToDevice, + s.stream_id_); + if(copy_result != hipSuccess) + { + throw std::runtime_error("Memory copy failed"); + } + } + + if(s.log_level_ > 0) + { + std::cout << "GEMM stage time: " << gemm_time << " ms" << std::endl; + if(args.k_batch > 1) + { + std::cout << "Reduction stage time: " << reduce_time << " ms" << std::endl; + } + std::cout << "Total time: " << gemm_time + reduce_time << " ms" << std::endl; + } + + return gemm_time + reduce_time; +} + +/** + * @brief High-level interface for two-stage split-K GEMM execution + * + * @param a_m_k_dev_buf Input matrix A device buffer + * @param b_k_n_dev_buf Input matrix B device buffer + * @param c_m_n_dev_buf Output matrix C device buffer + * @param M Matrix M dimension + * @param N Matrix N dimension + * @param K Matrix K dimension + * @param stride_A Memory stride for matrix A + * @param stride_B Memory stride for matrix B + * @param stride_C Memory stride for matrix C + * @param kbatch Number of K-splits for split-K execution + * @param n_warmup Number of warmup iterations + * @param n_repeat Number of repeat iterations for benchmarking + * @param persistent Whether to use persistent kernel execution + * @return Average execution time in milliseconds + */ +template +float invoke_gemm_splitk_two_stage(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat, + bool persistent) +{ + // Calculate workspace size: kbatch * M * N elements + const ck_tile::index_t workspace_size = kbatch * M * N * sizeof(CDataType); + const ck_tile::index_t workspace_stride = stride_C; // Stride for k_batch dimension + + // Allocate workspace memory + ck_tile::DeviceMem workspace_buf(workspace_size); + workspace_buf.SetZero(); + + // Create extended args for two-stage approach + GemmSplitKHostArgs args{ + a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr + b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr + workspace_buf.GetDeviceBuffer(), // workspace_ptr (used as e_ptr for stage 1) + c_m_n_dev_buf.GetDeviceBuffer(), // final_output_ptr + kbatch, // k_batch + M, + N, + K, // dimensions + stride_A, + stride_B, // input strides + workspace_stride, // workspace stride + stride_C // final output stride + }; + + float ave_time; + ck_tile::stream_config config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}; + + if(persistent) + { + ave_time = gemm_splitk_two_stage(args, config); + } + else + { + ave_time = gemm_splitk_two_stage(args, config); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Two-Stage GEMM+SplitK with M=" << M << " N=" << N << " K=" << K + << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + << " kbatch=" << kbatch << " WorkspaceSize=" << workspace_size << " bytes" + << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name + << " B_Type=" << DataTypeTraits::name + << " C_Type=" << DataTypeTraits::name + << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") + << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; + + return ave_time; +} + +// Two-stage implementation of run_gemm_example_with_layouts +template +int run_gemm_example_with_layouts_two_stage(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using AccDataType = typename GemmTypeConfig::AccDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + bool persistent = arg_parser.get_int("persistent"); + + const bool preshuffle = GemmConfig::Preshuffle; + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + if(init_method == 0) + { + if constexpr(preshuffle) + { + ck_tile::FillUniformDistribution{-.5f, .5f}(a_m_k); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + } + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + if(!preshuffle && GemmConfig::UseStructuredSparsity) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + static_assert(!GemmConfig::PermuteA, "Not implemented"); + + if constexpr(preshuffle) + { + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); + // shuffled buffer B for device implementation + b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); + } + else + { + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + if constexpr(GemmConfig::PermuteB) + { + permute_tensor_b(b_k_n_dev); + } + permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + if constexpr(GemmConfig::PermuteB) + { + std::cout << "Permute for this DataType is not implemented." << std::endl; + return false; + } + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + std::cout << "Using Workspace Split-K Mode (Two-Stage with Reduction)" << std::endl; + // Use the new two-stage approach + invoke_gemm_splitk_two_stage, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat, + persistent); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + if constexpr(GemmConfig::Preshuffle) + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + // memory on host to store gpu reference result + ck_tile::HostTensor c_m_n_gpu_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + // memory on device to store gpu reference result + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + auto [result, arg_parser] = create_args(argc, argv); + bool preshuffle = GemmConfig::Preshuffle; + + if(preshuffle && std::is_same_v) + { + throw std::runtime_error("Preshuffle is not supported for this int4 datatype!"); + } + + if(preshuffle && a_layout != "R" && b_layout != "C") + { + throw std::runtime_error( + "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); + } + + // Use new two-stage approach for both int4 and other data types + if constexpr(std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices when " + "BPrecType is ck_tile::pk_int4_t!"); + } + } + else + { + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Row{}, Row{}, Row{}); + } + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + return 0; +} + +template