diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 15903314f9..bd597344ea 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @shumway @vidyasagar-amd # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz +library/include/ 
@ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml new file mode 100644 index 0000000000..645a91c030 --- /dev/null +++ b/.github/workflows/therock-ci-linux.yml @@ -0,0 +1,128 @@ +name: TheRock CI Linux + +on: + workflow_call: + inputs: + cmake_options: + type: string + amdgpu_families: + type: string + test_runs_on: + type: string + +permissions: + contents: read + +jobs: + therock-build-linux: + name: Build Linux Packages + runs-on: azure-linux-scale-rocm + permissions: + id-token: write + container: + image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292 + env: + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + TEATIME_FORCE_INTERACTIVE: 0 + steps: + - name: Checkout composable_kernel repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Checkout TheRock repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/TheRock" + ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57 + path: "TheRock" + + - name: Runner Health Settings + run: | + df -h + cmake --version + echo "Installed Python versions:" + ls -d /opt/python + echo "python: $(which python), python3: $(which python3)" + echo "Git version: $(git --version)" + git config --global --add safe.directory $PWD + git config fetch.parallel 10 + + - name: Fetch sources + run: | + ./TheRock/build_tools/fetch_sources.py --jobs 12 + + - name: Install python deps + run: | + pip install -r TheRock/requirements.txt + pip freeze + + - name: Configure Projects + env: + amdgpu_families: ${{ env.AMDGPU_FAMILIES }} + package_version: ADHOCBUILD + extra_cmake_options: ${{ inputs.cmake_options }} + BUILD_DIR: build + run: | + python3 
TheRock/build_tools/github_actions/build_configure.py + + - name: Build TheRock + run: cmake --build TheRock/build + + - name: Build therock-archives + run: cmake --build TheRock/build --target therock-archives + + - name: Report + if: ${{ !cancelled() }} + run: | + echo "Full SDK du:" + echo "------------" + du -h -d 1 TheRock/build/dist/rocm + echo "Artifact Archives:" + echo "------------------" + ls -lh TheRock/build/artifacts/*.tar.xz + echo "Artifacts:" + echo "----------" + du -h -d 1 TheRock/build/artifacts + + - name: Configure AWS Credentials + if: always() + uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 + with: + aws-region: us-east-2 + role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external + + - name: Create Logs index Files and upload logs + if: always() + run: | + python3 TheRock/build_tools/github_actions/create_log_index.py \ + --build-dir=TheRock/build \ + --amdgpu-family=${{ env.AMDGPU_FAMILIES }} + + python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \ + --build-dir=TheRock/build \ + --run-id ${{ github.run_id }} \ + --amdgpu-family ${{ env.AMDGPU_FAMILIES }} + + - name: Upload artifacts + run: | + python TheRock/build_tools/github_actions/upload_build_artifacts.py \ + --run-id ${{ github.run_id }} \ + --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ + --build-dir TheRock/build + + - name: Add Links to Job Summary + if: always() + run: | + python TheRock/build_tools/github_actions/upload_build_summary.py \ + --run-id ${{ github.run_id }} \ + --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ + --build-dir TheRock/build + + therock-test-linux: + name: "Test" + needs: [therock-build-linux] + uses: ./.github/workflows/therock-test-packages.yml + with: + project_to_test: "miopen" + amdgpu_families: ${{ inputs.amdgpu_families }} + test_runs_on: ${{ inputs.test_runs_on }} + platform: "linux" diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml new 
file mode 100644 index 0000000000..18411baa09 --- /dev/null +++ b/.github/workflows/therock-ci.yml @@ -0,0 +1,50 @@ +name: TheRock CI for composable_kernel + +on: + push: + branches: + - develop + workflow_dispatch: + +permissions: + contents: read + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + therock-ci-linux: + name: TheRock CI Linux + permissions: + contents: read + id-token: write + uses: ./.github/workflows/therock-ci-linux.yml + secrets: inherit + with: + cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../" + amdgpu_families: "gfx94X-dcgpu" + test_runs_on: "linux-mi325-1gpu-ossci-rocm" + + therock_ci_summary: + name: TheRock CI Summary + if: always() + needs: + - therock-ci-linux + runs-on: ubuntu-24.04 + steps: + - name: Output failed jobs + run: | + echo '${{ toJson(needs) }}' + FAILED_JOBS="$(echo '${{ toJson(needs) }}' \ + | jq --raw-output \ + 'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \ + )" + if [[ "${FAILED_JOBS}" != "" ]]; then + echo "The following jobs failed: ${FAILED_JOBS}" + exit 1 + fi diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml new file mode 100644 index 0000000000..439135743c --- /dev/null +++ b/.github/workflows/therock-test-packages.yml @@ -0,0 +1,76 @@ +name: TheRock Test Packages + +on: + workflow_call: + inputs: + project_to_test: + type: string + amdgpu_families: + type: string + test_runs_on: + type: string + platform: + type: string + +permissions: + contents: read + +jobs: + configure_test_matrix: + 
name: "Configure test matrix" + runs-on: ubuntu-24.04 + if: ${{ inputs.test_runs_on != '' }} + outputs: + components: ${{ steps.configure.outputs.components }} + steps: + - name: "Checking out repository" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/TheRock" + + - name: "Configuring CI options" + env: + PLATFORM: ${{ inputs.platform }} + project_to_test: ${{ inputs.project_to_test }} + id: configure + run: python ./build_tools/github_actions/fetch_test_configurations.py + + test_components: + name: 'Test ${{ matrix.components.job_name }}' + runs-on: ${{ inputs.test_runs_on }} + needs: configure_test_matrix + # skip tests if no test matrix to run + if: ${{ needs.configure_test_matrix.outputs.components != '[]' }} + strategy: + fail-fast: false + matrix: + components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }} + defaults: + run: + shell: bash + env: + VENV_DIR: ${{ github.workspace }}/.venv + ARTIFACT_RUN_ID: "${{ github.run_id }}" + OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build + THEROCK_BIN_DIR: "./build/bin" + steps: + - name: Checkout Repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/TheRock" + + - name: Run setup test environment workflow + uses: './.github/actions/setup_test_environment' + with: + ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} + VENV_DIR: ${{ env.VENV_DIR }} + FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }} + PLATFORM: ${{ inputs.platform }} + + - name: Test + timeout-minutes: ${{ matrix.components.timeout_minutes }} + run: | + if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . 
${VENV_DIR}/Scripts/activate ; fi + ${{ matrix.components.test_script }} diff --git a/.gitignore b/.gitignore index f4d5ff7abd..e4dd8f7513 100644 --- a/.gitignore +++ b/.gitignore @@ -55,6 +55,8 @@ _static/ _templates/ _toc.yml _doxygen/ +docs/doxygen/html +docs/doxygen/xml # JetBrains IDE .idea/ @@ -66,3 +68,6 @@ build*/ # Python cache __pycache__/ + +.cache/ + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml old mode 100755 new mode 100644 index d6700ae05b..664c5219e2 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: hooks: - id: clang-format name: clang-format - entry: clang-format-12 -i --style=file + entry: clang-format-18 -i --style=file language: system types_or: [c++, inc] - id: copyright-year-checker @@ -12,3 +12,27 @@ repos: verbose: false language: script types: [c++] + - id: remove-exec-bit + name: Remove executable bit from non-executable files + entry: script/remove_exec_bit.sh + language: script + types_or: [c++, text] + verbose: true + - id: ruff-check + name: Ruff Linter + entry: ruff check --fix + language: python + types: [python] + additional_dependencies: [ruff] + - id: ruff-format + name: Ruff Formatter + entry: ruff format + language: python + types: [python] + additional_dependencies: [ruff] + - id: run-remod-if-ck-tile-changed + name: Run remod.py if ck_tile files changed + entry: script/remod_for_ck_tile.sh + language: script + always_run: true + pass_filenames: false diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d07abfc24..7c09271edc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,16 +2,39 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). -## Composable Kernel 1.1.0 for ROCm 6.5.0 +## Composable Kernel 1.1.0 for ROCm 7.0.0 ### Added +* Added a basic copy kernel example and supporting documentation for new CK Tile developers. 
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data +* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). +* Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW). +* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). +* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). +* Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added support for Multiple D GEMM +* Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types +* Added support for FP16 2:4 structured sparsity to universal GEMM. +* Added support for Split K for grouped convolution backward data. +* Added logit soft-capping support for fMHA forward kernels. +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv/bwd) +* Added benchmarking support for tile engine GEMM. +* Added Ping-pong scheduler support for GEMM operation along the K dimension. +* Added rotating buffer feature for CK_Tile GEMM. +* Added int8 support for CK_TILE GEMM. +* Added support for elementwise kernel. +* Added benchmarking support for tile engine GEMM Multi D. ### Optimized -None + +* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166) +* Added Vectorize Transpose optimization for CK Tile (#2131) +* Added the asynchronous copy for gfx950 (#2425) + ### Fixes @@ -22,11 +45,18 @@ None * Removed support for gfx940 and gfx941 targets (#1944) * Replaced the raw buffer load/store intrinsics with Clang20 built-ins (#1876) * DL and DPP kernels are now enabled by default. 
+* Number of instances in instance factory for grouped convolution forward NGCHW/GKYXC/NGKHW has been reduced. +* Number of instances in instance factory for grouped convolution backward weight NGCHW/GKYXC/NGKHW has been reduced. +* Number of instances in instance factory for grouped convolution backward data NGCHW/GKYXC/NGKHW has been reduced. ### Known issues None +### Upcoming changes + +* Non-grouped convolutions are deprecated. All of their functionality is supported by grouped convolution. + ## Composable Kernel 1.1.0 for ROCm 6.1.0 ### Additions diff --git a/CMakeLists.txt b/CMakeLists.txt index bb0c254e06..07d2e166bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,17 +26,21 @@ set(version 1.1.0) project(composable_kernel VERSION ${version} LANGUAGES CXX HIP) include(CTest) +option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) +option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) + # Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8" # CK Codegen requires dataclass which is added in Python 3.7 # Python version 3.8 is required for general good practice as it is default for Ubuntu 20.04 if(NOT CK_USE_ALTERNATIVE_PYTHON) find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED) else() - message("Using alternative python version") + message(STATUS "Using alternative python version") set(EXTRA_PYTHON_PATH) # this is overly restrictive, we may need to be more flexible on the following string(REPLACE "/bin/python3.8" "" EXTRA_PYTHON_PATH "${CK_USE_ALTERNATIVE_PYTHON}") - message("alternative python path is: ${EXTRA_PYTHON_PATH}") + message(STATUS "alternative python path is: ${EXTRA_PYTHON_PATH}") find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED) add_definitions(-DPython3_EXECUTABLE="${CK_USE_ALTERNATIVE_PYTHON}") set(Python3_EXECUTABLE "${CK_USE_ALTERNATIVE_PYTHON}") @@ -76,7 +80,7 @@ 
if (DTYPES) add_definitions(-DCK_ENABLE_BF16) set(CK_ENABLE_BF16 "ON") endif() - message("DTYPES macro set to ${DTYPES}") + message(STATUS "DTYPES macro set to ${DTYPES}") else() add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8) set(CK_ENABLE_INT8 "ON") @@ -94,6 +98,15 @@ add_compile_options(-Wno-pass-failed) add_compile_options(-Wno-switch-default) add_compile_options(-Wno-unique-object-duplication) +# add -Og -gdwarf64 for debug builds +add_compile_options( + "$<$:-Og>" + "$<$:-gdwarf64>" +) + +# Recent change in compiler makes this warning ON by default, which led to compile errors. +add_compile_options(-Wno-nrvo) + if(NOT DISABLE_DL_KERNELS) add_definitions(-DDL_KERNELS) set(DL_KERNELS "ON") @@ -139,8 +152,8 @@ rocm_setup_version(VERSION ${version}) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}") -message("GPU_TARGETS= ${GPU_TARGETS}") -message("GPU_ARCHS= ${GPU_ARCHS}") +message(STATUS "GPU_TARGETS= ${GPU_TARGETS}") +message(STATUS "GPU_ARCHS= ${GPU_ARCHS}") if(GPU_ARCHS) #disable GPU_TARGETS to avoid conflicts, this needs to happen before we call hip package unset(GPU_TARGETS CACHE) @@ -155,9 +168,9 @@ find_package(hip REQUIRED) # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}") -message("hip_version_flat=${hip_VERSION_FLAT}") +message(STATUS "hip_version_flat=${hip_VERSION_FLAT}") -message("checking which targets are supported") +message(STATUS "checking which targets are supported") #In order to build just the CK library (without tests and examples) for all supported GPU targets #use -D 
GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" #the GPU_TARGETS flag will be reset in this case in order to avoid conflicts. @@ -167,8 +180,12 @@ if(NOT ENABLE_ASAN_PACKAGING) if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000) # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") - else() + elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600300000 AND ${hip_VERSION_FLAT} LESS 600400000) set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") + elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600400000 AND ${hip_VERSION_FLAT} LESS 600443483) + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx950") + elseif(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER_EQUAL 600443483) + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx950;gfx10-3-generic;gfx11-generic;gfx12-generic") endif() else() #build CK only for xnack-supported targets when using ASAN @@ -192,23 +209,28 @@ endif() rocm_check_target_ids(SUPPORTED_GPU_TARGETS TARGETS ${CK_GPU_TARGETS}) -message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") +message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") - message("Enabling XDL instances") + message(STATUS "Enabling XDL instances") add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") - message("Enabling FP8 gemms on native architectures") + message(STATUS "Enabling XDL FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") - message("Enabling WMMA instances") + message(STATUS "Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) 
set(CK_USE_WMMA "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") + message(STATUS "Enabling WMMA FP8 gemms on native architectures") + add_definitions(-DCK_USE_WMMA_FP8) + set(CK_USE_WMMA_FP8 "ON") +endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") @@ -220,6 +242,8 @@ endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) set(CK_USE_NATIVE_MX_SUPPORT "ON") + add_definitions(-DCK_GFX950_SUPPORT) + set(CK_GFX950_SUPPORT "ON") endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) @@ -234,32 +258,32 @@ configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/con if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500723302) check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK) if(HAS_NO_OFFLOAD_UNIFORM_BLOCK) - message("Adding the fno-offload-uniform-block compiler flag") + message(STATUS "Adding the fno-offload-uniform-block compiler flag") add_compile_options(-fno-offload-uniform-block) endif() endif() if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 500500000) check_cxx_compiler_flag("-mllvm --lsr-drop-solution=1" HAS_LSR_DROP_SOLUTION) if(HAS_LSR_DROP_SOLUTION) - message("Adding the lsr-drop-solution=1 compiler flag") + message(STATUS "Adding the lsr-drop-solution=1 compiler flag") add_compile_options("SHELL: -mllvm --lsr-drop-solution=1") endif() endif() if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090) check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED) if(HAS_ENABLE_POST_MISCHED) - message("Adding the enable-post-misched=0 compiler flag") + message(STATUS "Adding the enable-post-misched=0 compiler flag") add_compile_options("SHELL: -mllvm -enable-post-misched=0") endif() endif() set(check-coerce) check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) if(NOT WIN32 AND 
check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132) - message("Adding the amdgpu-coerce-illegal-types=1") + message(STATUS "Adding the amdgpu-coerce-illegal-types=1") add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1") endif() if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) - message("Adding -amdgpu-early-inline-all=true and -amdgpu-function-calls=false") + message(STATUS "Adding -amdgpu-early-inline-all=true and -amdgpu-function-calls=false") add_compile_options("SHELL: -mllvm -amdgpu-early-inline-all=true") add_compile_options("SHELL: -mllvm -amdgpu-function-calls=false") endif() @@ -292,17 +316,31 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) +option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) add_compile_options(-Wno-bit-int-extension) - message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") + message(STATUS "CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() if(USE_OPT_GFX11) add_compile_options(-mcumode) add_compile_options(-mno-wavefrontsize64) - message("CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") + add_compile_definitions(CK_TILE_WAVE32_ENABLED) + message(STATUS "CK compiled with USE_OPT_GFX11 set to ${USE_OPT_GFX11}") +endif() + +if(ENABLE_ASM_DUMP) + add_compile_options(--save-temps) + add_compile_options(-Wno-gnu-line-marker) + message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") +endif() + +if(USE_OPT_GFX12 AND (SUPPORTED_GPU_TARGETS MATCHES "gfx12")) + add_compile_options(-mno-wavefrontsize64) + add_compile_definitions(CK_TILE_WAVE32_ENABLED) + message(STATUS "CK compiled with USE_OPT_GFX12 set to ${USE_OPT_GFX12}") endif() ## Threads @@ 
-311,10 +349,10 @@ find_package(Threads REQUIRED) link_libraries(Threads::Threads) ## C++ -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") +message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") # https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_macros.html # _GLIBCXX_ASSERTIONS @@ -330,7 +368,7 @@ endif() set(CMAKE_HIP_PLATFORM amd) set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER}) set(CMAKE_HIP_EXTENSIONS ON) -message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") +message(STATUS "CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}") ## OpenMP if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") @@ -345,10 +383,10 @@ else() find_package(OpenMP REQUIRED) endif() -message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}") -message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") -message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") -message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") +message(STATUS "OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}") +message(STATUS "OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}") +message(STATUS "OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}") +message(STATUS "OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}") link_libraries(${OpenMP_gomp_LIBRARY}) link_libraries(${OpenMP_pthread_LIBRARY}) @@ -380,146 +418,152 @@ else() add_compile_definitions(__HIP_PLATFORM_HCC__=1) endif() -## tidy include(EnableCompilerWarnings) +## tidy set(CK_TIDY_ERRORS ERRORS * -readability-inconsistent-declaration-parameter-name) if(CMAKE_CXX_COMPILER MATCHES ".*hcc" OR CMAKE_CXX_COMPILER MATCHES ".*clang\\+\\+") - set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) +set(CK_TIDY_CHECKS -modernize-use-override -readability-non-const-parameter) # Enable tidy on hip elseif(CK_BACKEND STREQUAL "HIP" OR CK_BACKEND STREQUAL "HIPNOGPU") - set(CK_TIDY_ERRORS ALL) +set(CK_TIDY_ERRORS ALL) endif() -include(ClangTidy) -enable_clang_tidy( - 
CHECKS - * - -abseil-* - -android-cloexec-fopen - # Yea we shouldn't be using rand() - -cert-msc30-c - -bugprone-exception-escape - -bugprone-macro-parentheses - -cert-env33-c - -cert-msc32-c - -cert-msc50-cpp - -cert-msc51-cpp - -cert-dcl37-c - -cert-dcl51-cpp - -clang-analyzer-alpha.core.CastToStruct - -clang-analyzer-optin.performance.Padding - -clang-diagnostic-deprecated-declarations - -clang-diagnostic-extern-c-compat - -clang-diagnostic-unused-command-line-argument - -cppcoreguidelines-avoid-c-arrays - -cppcoreguidelines-avoid-magic-numbers - -cppcoreguidelines-explicit-virtual-functions - -cppcoreguidelines-init-variables - -cppcoreguidelines-macro-usage - -cppcoreguidelines-non-private-member-variables-in-classes - -cppcoreguidelines-pro-bounds-array-to-pointer-decay - -cppcoreguidelines-pro-bounds-constant-array-index - -cppcoreguidelines-pro-bounds-pointer-arithmetic - -cppcoreguidelines-pro-type-member-init - -cppcoreguidelines-pro-type-reinterpret-cast - -cppcoreguidelines-pro-type-union-access - -cppcoreguidelines-pro-type-vararg - -cppcoreguidelines-special-member-functions - -fuchsia-* - -google-explicit-constructor - -google-readability-braces-around-statements - -google-readability-todo - -google-runtime-int - -google-runtime-references - -hicpp-vararg - -hicpp-braces-around-statements - -hicpp-explicit-conversions - -hicpp-named-parameter - -hicpp-no-array-decay - # We really shouldn't use bitwise operators with signed integers, but - # opencl leaves us no choice - -hicpp-avoid-c-arrays - -hicpp-signed-bitwise - -hicpp-special-member-functions - -hicpp-uppercase-literal-suffix - -hicpp-use-auto - -hicpp-use-equals-default - -hicpp-use-override - -llvm-header-guard - -llvm-include-order - #-llvmlibc-* - -llvmlibc-restrict-system-libc-headers - -llvmlibc-callee-namespace - -llvmlibc-implementation-in-namespace - -llvm-else-after-return - -llvm-qualified-auto - -misc-misplaced-const - -misc-non-private-member-variables-in-classes - 
-misc-no-recursion - -modernize-avoid-bind - -modernize-avoid-c-arrays - -modernize-pass-by-value - -modernize-use-auto - -modernize-use-default-member-init - -modernize-use-equals-default - -modernize-use-trailing-return-type - -modernize-use-transparent-functors - -performance-unnecessary-value-param - -readability-braces-around-statements - -readability-else-after-return - # we are not ready to use it, but very useful - -readability-function-cognitive-complexity - -readability-isolate-declaration - -readability-magic-numbers - -readability-named-parameter - -readability-uppercase-literal-suffix - -readability-convert-member-functions-to-static - -readability-qualified-auto - -readability-redundant-string-init - # too many narrowing conversions in our code - -bugprone-narrowing-conversions - -cppcoreguidelines-narrowing-conversions - -altera-struct-pack-align - -cppcoreguidelines-prefer-member-initializer - ${CK_TIDY_CHECKS} - ${CK_TIDY_ERRORS} - HEADER_FILTER - "\.hpp$" - EXTRA_ARGS - -DCK_USE_CLANG_TIDY -) +if(ENABLE_CLANG_CPP_CHECKS) + include(ClangTidy) + enable_clang_tidy( + CHECKS + * + -abseil-* + -android-cloexec-fopen + # Yea we shouldn't be using rand() + -cert-msc30-c + -bugprone-exception-escape + -bugprone-macro-parentheses + -cert-env33-c + -cert-msc32-c + -cert-msc50-cpp + -cert-msc51-cpp + -cert-dcl37-c + -cert-dcl51-cpp + -clang-analyzer-alpha.core.CastToStruct + -clang-analyzer-optin.performance.Padding + -clang-diagnostic-deprecated-declarations + -clang-diagnostic-extern-c-compat + -clang-diagnostic-unused-command-line-argument + -cppcoreguidelines-avoid-c-arrays + -cppcoreguidelines-avoid-magic-numbers + -cppcoreguidelines-explicit-virtual-functions + -cppcoreguidelines-init-variables + -cppcoreguidelines-macro-usage + -cppcoreguidelines-non-private-member-variables-in-classes + -cppcoreguidelines-pro-bounds-array-to-pointer-decay + -cppcoreguidelines-pro-bounds-constant-array-index + -cppcoreguidelines-pro-bounds-pointer-arithmetic + 
-cppcoreguidelines-pro-type-member-init + -cppcoreguidelines-pro-type-reinterpret-cast + -cppcoreguidelines-pro-type-union-access + -cppcoreguidelines-pro-type-vararg + -cppcoreguidelines-special-member-functions + -fuchsia-* + -google-explicit-constructor + -google-readability-braces-around-statements + -google-readability-todo + -google-runtime-int + -google-runtime-references + -hicpp-vararg + -hicpp-braces-around-statements + -hicpp-explicit-conversions + -hicpp-named-parameter + -hicpp-no-array-decay + # We really shouldn't use bitwise operators with signed integers, but + # opencl leaves us no choice + -hicpp-avoid-c-arrays + -hicpp-signed-bitwise + -hicpp-special-member-functions + -hicpp-uppercase-literal-suffix + -hicpp-use-auto + -hicpp-use-equals-default + -hicpp-use-override + -llvm-header-guard + -llvm-include-order + #-llvmlibc-* + -llvmlibc-restrict-system-libc-headers + -llvmlibc-callee-namespace + -llvmlibc-implementation-in-namespace + -llvm-else-after-return + -llvm-qualified-auto + -misc-misplaced-const + -misc-non-private-member-variables-in-classes + -misc-no-recursion + -modernize-avoid-bind + -modernize-avoid-c-arrays + -modernize-pass-by-value + -modernize-use-auto + -modernize-use-default-member-init + -modernize-use-equals-default + -modernize-use-trailing-return-type + -modernize-use-transparent-functors + -performance-unnecessary-value-param + -readability-braces-around-statements + -readability-else-after-return + # we are not ready to use it, but very useful + -readability-function-cognitive-complexity + -readability-isolate-declaration + -readability-magic-numbers + -readability-named-parameter + -readability-uppercase-literal-suffix + -readability-convert-member-functions-to-static + -readability-qualified-auto + -readability-redundant-string-init + # too many narrowing conversions in our code + -bugprone-narrowing-conversions + -cppcoreguidelines-narrowing-conversions + -altera-struct-pack-align + 
-cppcoreguidelines-prefer-member-initializer + ${CK_TIDY_CHECKS} + ${CK_TIDY_ERRORS} + HEADER_FILTER + "\.hpp$" + EXTRA_ARGS + -DCK_USE_CLANG_TIDY + ) -include(CppCheck) -enable_cppcheck( - CHECKS - warning - style - performance - portability - SUPPRESS - ConfigurationNotChecked - constStatement - duplicateCondition - noExplicitConstructor - passedByValue - preprocessorErrorDirective - shadowVariable - unusedFunction - unusedPrivateFunction - unusedStructMember - unmatchedSuppression - FORCE - SOURCES - library/src - INCLUDE - ${CMAKE_CURRENT_SOURCE_DIR}/include - ${CMAKE_CURRENT_BINARY_DIR}/include - ${CMAKE_CURRENT_SOURCE_DIR}/library/include - DEFINE - CPPCHECK=1 - __linux__=1 -) + include(CppCheck) + enable_cppcheck( + CHECKS + warning + style + performance + portability + SUPPRESS + ConfigurationNotChecked + constStatement + duplicateCondition + noExplicitConstructor + passedByValue + preprocessorErrorDirective + shadowVariable + unusedFunction + unusedPrivateFunction + unusedStructMember + unmatchedSuppression + FORCE + SOURCES + library/src + INCLUDE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_BINARY_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/library/include + DEFINE + CPPCHECK=1 + __linux__=1 + ) +else() + function(clang_tidy_check TARGET) + # stub out empty function if clang tidy is not enabled + endfunction() +endif() set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) @@ -538,7 +582,7 @@ if(BUILD_DEV) add_compile_options(-Werror) add_compile_options(-Weverything) endif() -message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang") add_compile_options(-fcolor-diagnostics) @@ -547,12 +591,15 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS add_compile_options(-fdiagnostics-color=always) endif() -# make check runs the entire set of examples 
and tests -add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) -# make smoke runs the tests and examples that runs within 30 seconds on gfx90a -add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST") -# make regression runs the tests and examples that runs for more 30 seconds on gfx90a -add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST") +if(NOT MIOPEN_REQ_LIBS_ONLY) + # make check runs the entire set of examples and tests + add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) + # make smoke runs the tests and examples that run within 30 seconds on gfx90a + add_custom_target(smoke COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "SMOKE_TEST") + # make regression runs the tests and examples that run for more than 30 seconds on gfx90a + add_custom_target(regression COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "REGRESSION_TEST") +endif() + file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp") @@ -595,9 +642,14 @@ ENDIF() ENDFOREACH() add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES}) + +option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) +option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF) +option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) + add_subdirectory(library) -if(NOT GPU_ARCHS AND USER_GPU_TARGETS) +if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY) rocm_package_setup_component(tests LIBRARY_NAME composablekernel PACKAGE_NAME tests # Prevent -static suffix on package name @@ -608,16 +660,19 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS) PACKAGE_NAME examples ) add_subdirectory(example) + 
add_subdirectory(tile_engine) if(BUILD_TESTING) add_subdirectory(test) endif() endif() -rocm_package_setup_component(profiler - LIBRARY_NAME composablekernel - PACKAGE_NAME ckprofiler -) -add_subdirectory(profiler) +if (NOT MIOPEN_REQ_LIBS_ONLY) + rocm_package_setup_component(profiler + LIBRARY_NAME composablekernel + PACKAGE_NAME ckprofiler + ) + add_subdirectory(profiler) +endif() if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS)) add_subdirectory(codegen) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 8ef5c2b726..0900b7a1f8 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -20,10 +20,11 @@ Tejash Shah, 2019-2020 Xiaoyan Zhou, 2020 [Jianfeng Yan](https://github.com/j4yan), 2021-2022 - +[Jun Liu](https://github.com/junliume), 2021-2024 ## Product Manager -[Jun Liu](https://github.com/junliume) +[John Afaganis](https://github.com/afagaj) + ## Contributors diff --git a/Dockerfile b/Dockerfile index 17800d92d5..6f5cd0115d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ -FROM ubuntu:22.04 +FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=6.3 +ARG ROCMVERSION=6.4.1 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" @@ -9,19 +9,18 @@ ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn # Add rocm repository RUN set -xe && \ - useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins && \ apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN if [ "$ROCMVERSION" != "6.4" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/focal/amdgpu-install_6.3.60300-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.3.60300-1_all.deb && \ +RUN if [ "$ROCMVERSION" != "6.5" ]; then \ + sh -c "wget 
https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60401-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60401-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ + sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \ + sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \ fi -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" && \ +RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu jammy main universe | tee -a /etc/apt/sources.list" && \ amdgpu-install -y --usecase=rocm --no-dkms ## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined @@ -44,17 +43,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- iputils-ping \ jq \ libelf-dev \ - libncurses5-dev \ libnuma-dev \ libpthread-stubs0-dev \ llvm-amdgpu \ mpich \ net-tools \ pkg-config \ - python \ - python3 \ - python3-dev \ - python3-pip \ + python3-full \ redis \ rocm-llvm-dev \ sshpass \ @@ -67,6 +62,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libzstd-dev \ openssh-server \ clang-format-12 \ + clang-format-18 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ @@ -74,17 +70,14 @@ RUN apt-get update && 
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- # Remove unnecessary rocm components that take a lot of space apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt -# Update the cmake to version 3.27.5 -RUN pip install --upgrade cmake==3.27.5 && \ #Install latest ccache - git clone https://github.com/ccache/ccache.git && \ +RUN git clone https://github.com/ccache/ccache.git && \ cd ccache && mkdir build && cd build && cmake .. && make install && \ #Install ninja build tracing tools cd / && \ wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip && \ gunzip /usr/local/bin/ninja.gz && \ chmod a+x /usr/local/bin/ninja && \ - git clone https://github.com/nico/ninjatracing.git && \ #Install ClangBuildAnalyzer git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \ cd ClangBuildAnalyzer/ && \ @@ -98,8 +91,7 @@ RUN pip install --upgrade cmake==3.27.5 && \ wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb && \ dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ # Install packages for processing the performance results - pip3 install --upgrade pip && \ - pip3 install --upgrade pytest sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust setuptools>=75 sshtunnel==0.4.0 && \ + pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ # Add render group groupadd -f render && \ # Install the new rocm-cmake version diff --git a/Dockerfile.aiter b/Dockerfile.aiter new file mode 100644 index 0000000000..245e39fb75 --- /dev/null +++ b/Dockerfile.aiter @@ -0,0 +1,21 @@ +ARG BASE_DOCKER="rocm/pytorch:latest" +FROM $BASE_DOCKER +ARG AITER_BRANCH="main" +ARG CK_AITER_BRANCH="develop" +RUN groupadd -g 109 render && \ + usermod -u 1001 jenkins && \ + groupmod -g 1001 jenkins && \ + pip install pandas zmq einops && \ + pip install numpy==1.26.2 && \ + sudo mkdir 
/home/jenkins && \ + sudo mkdir /home/jenkins/workspace && \ + cd /home/jenkins/workspace && \ + rm -rf aiter && \ + git clone -b "$AITER_BRANCH" --recursive https://github.com/ROCm/aiter.git && \ + cd aiter && \ + rm -rf 3rdparty/composable_kernel/ && \ + git clone -b "$CK_AITER_BRANCH" https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \ + python3 setup.py develop && \ + chown -R jenkins:jenkins /home/jenkins/workspace && \ + chmod -R a+rwx /home/jenkins/workspace && \ + sudo usermod -aG irc jenkins diff --git a/Dockerfile.compiler b/Dockerfile.compiler index a22103b96b..0306057e45 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub22.04_rocm6.3" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4.1" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index 86cac3c485..d1f1baf15f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -12,6 +12,23 @@ def show_node_info() { """ } +class Version { + int major, minor, patch + @Override + String toString() { + return [major, minor, patch].findAll().join('.') + } +} +def parseVersion(String versionString) { + if (!versionString) return null + int[] tokens = versionString.split(/\./).collect { it as int } // Splits the string by '.' and converts each part to an integer. + return new Version( + major: tokens[0], + minor: tokens.length > 1 ? tokens[1] : null, + patch: tokens.length > 2 ? 
tokens[2] : null, + ) +} + def nthreads() { def nproc = sh(returnStdout: true, script: 'nproc') echo "Number of cores: ${nproc}" @@ -38,12 +55,12 @@ def getBaseDockerImageName(){ img = "${params.USE_CUSTOM_DOCKER}" } else{ - def ROCM_numeric = "${params.ROCMVERSION}" as float - if ( ROCM_numeric < 6.4 ){ - img = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm${params.ROCMVERSION}" + def ROCM_numeric = parseVersion("${params.ROCMVERSION}") + if ( ROCM_numeric.major <= 6 && ROCM_numeric.minor < 5 ){ + img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ - img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm${params.ROCMVERSION}" + img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm${params.ROCMVERSION}" } } return img @@ -76,6 +93,7 @@ def check_host() { if ("${env.CK_SCCACHE}" != "null"){ def SCCACHE_SERVER="${env.CK_SCCACHE.split(':')[0]}" echo "sccache server: ${SCCACHE_SERVER}" + sh "chmod +w -R ${env.WORKSPACE}" sh '''ping -c 1 -p 6379 "${SCCACHE_SERVER}" | echo $? > tmp.txt''' def output = readFile(file: "tmp.txt") echo "tmp.txt contents: \$output" @@ -92,6 +110,33 @@ def build_compiler(){ return compiler } +def check_arch(){ + def arch_type = 0 + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx90a" rocminfo.log') ){ + arch_type = 1 + } + else if ( runShell('grep -n "gfx942" rocminfo.log') ) { + arch_type = 2 + } + else if ( runShell('grep -n "gfx10" rocminfo.log') ) { + arch_type = 3 + } + else if ( runShell('grep -n "gfx11" rocminfo.log') ) { + arch_type = 4 + } + else if ( runShell('grep -n "gfx12" rocminfo.log') ) { + arch_type = 5 + } + else if ( runShell('grep -n "gfx908" rocminfo.log') ) { + arch_type = 6 + } + else if ( runShell('grep -n "gfx950" rocminfo.log') ) { + arch_type = 7 + } + return arch_type +} + def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 def prefixpath = conf.get("prefixpath", "/opt/rocm") @@ -107,6 +152,10 @@ def getDockerImage(Map conf=[:]){ image = conf.get("docker_name", "") echo "Using legacy docker: 
${image}" } + else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){ + image = conf.get("docker_name", "") + echo "Using special docker: ${image}" + } else{ image = getDockerImageName() echo "Using default docker: ${image}" @@ -115,7 +164,7 @@ def getDockerImage(Map conf=[:]){ def retimage try { - echo "Pulling down image: ${image}" + echo "Pulling image: ${image}" retimage = docker.image("${image}") withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.pull() @@ -139,12 +188,16 @@ def buildDocker(install_prefix){ if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f Dockerfile.compiler . " } + else if(params.RUN_AITER_TESTS){ + image_name = "rocm/composable_kernel:ck_aiter" + dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . " + } else{ dockerArgs = dockerArgs + " -f Dockerfile . " } echo "Build Args: ${dockerArgs}" try{ - if(params.BUILD_DOCKER){ + if(params.BUILD_DOCKER || params.RUN_AITER_TESTS){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" retimage = docker.build("${image_name}", dockerArgs) @@ -176,7 +229,9 @@ def cmake_build(Map conf=[:]){ def build_envs = "CTEST_PARALLEL_LEVEL=4 " + conf.get("build_env","") def prefixpath = conf.get("prefixpath","/opt/rocm") def setup_args = conf.get("setup_args","") - + // make sure all unit tests always run on develop branch + def runAllUnitTests = (env.BRANCH_NAME == "develop") ? 
true : params.RUN_ALL_UNIT_TESTS + if (prefixpath != "/usr/local"){ setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} " } @@ -238,6 +293,9 @@ def cmake_build(Map conf=[:]){ if (setup_args.contains("gfx94")){ invocation_tag="gfx94" } + if (setup_args.contains("gfx95")){ + invocation_tag="gfx95" + } echo "invocation tag: ${invocation_tag}" def redis_pre_setup_cmd = pre_setup_cmd if(check_host() && params.USE_SCCACHE && "${env.CK_SCCACHE}" != "null" && "${invocation_tag}" != "") { @@ -286,20 +344,19 @@ def cmake_build(Map conf=[:]){ def build_cmd def execute_cmd = conf.get("execute_cmd", "") if(!setup_args.contains("NO_CK_BUILD")){ - if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ + def cmake_flags = params.NINJA_FTIME_TRACE ? "-O3 -ftime-trace" : "-O3" + if (params.NINJA_BUILD_TRACE) { echo "running ninja build trace" - setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """) - build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") - } - else if (setup_args.contains("gfx908;gfx90a;gfx942")){ - //limit the number of build threads when building for multiple gfx9 targets - setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") - build_cmd = conf.get("build_cmd", "${build_envs} make -j32 ${config_targets}") - } - else{ - setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") - build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}") } + setup_cmd = conf.get( + "setup_cmd", + """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" ${cmake_flags} " .. 
""" + ) + build_cmd = conf.get( + "build_cmd", + "${build_envs} ninja -j${nt} ${config_targets}" + ) + cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} @@ -319,44 +376,61 @@ def cmake_build(Map conf=[:]){ sh cmd //run tests except when NO_CK_BUILD or BUILD_LEGACY_OS are set if(!setup_args.contains("NO_CK_BUILD") && !params.BUILD_LEGACY_OS){ - if (setup_args.contains("gfx90a") && params.NINJA_BUILD_TRACE){ - sh "/ninjatracing/ninjatracing .ninja_log > ck_build_trace.json" - sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" - sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log" + if ((setup_args.contains("gfx9") && params.NINJA_BUILD_TRACE) || params.BUILD_INSTANCES_ONLY){ + if (params.NINJA_FTIME_TRACE) { + echo "running ninja ftime trace" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --all . clang_build.log" + sh "/ClangBuildAnalyzer/build/ClangBuildAnalyzer --analyze clang_build.log > clang_build_analysis.log" + archiveArtifacts "clang_build_analysis.log" + } + sh "python3 ../script/ninja_json_converter.py .ninja_log --legacy-format --output ck_build_trace.json" archiveArtifacts "ck_build_trace.json" - archiveArtifacts "clang_build_analysis.log" + // do not run unit tests when building instances only if(!params.BUILD_INSTANCES_ONLY){ - sh "ninja test" + if (!runAllUnitTests){ + sh "../script/launch_tests.sh" + } + else{ + sh "ninja check" + } + } + if(params.BUILD_INSTANCES_ONLY){ + // build deb packages + echo "Build packages" + sh 'ninja -j64 package' + archiveArtifacts artifacts: 'composablekernel-dev*.deb' + sh 'mv composablekernel-dev_*.deb composablekernel-dev_all_targets_1.1.0_amd64.deb' + stash includes: "composablekernel-dev_all_targets_1.1.0_amd64.deb", name: "packages" } } else{ - // run unit tests - sh "make check" + // run unit tests unless building library for all targets + if (!params.BUILD_INSTANCES_ONLY){ + if (!runAllUnitTests){ + sh "../script/launch_tests.sh" + } 
+ else{ + sh "ninja check" + } + } } } } - // Only archive from master or develop - if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { + // Only archive from develop + if (package_build == true && env.BRANCH_NAME == "develop") { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } //check the node gpu architecture - def arch_type = 0 - sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx90a" rocminfo.log') ){ - arch_type = 1 - } - else if ( runShell('grep -n "gfx942" rocminfo.log') ) { - arch_type = 2 - } + def arch = check_arch() if (params.RUN_CK_TILE_FMHA_TESTS){ try{ archiveArtifacts "perf_fmha_*.log" - if (arch_type == 1){ + if (arch == 1){ stash includes: "perf_fmha_**_gfx90a.log", name: "perf_fmha_log_gfx90a" } - else if (arch_type == 2){ + else if (arch == 2){ stash includes: "perf_fmha_**_gfx942.log", name: "perf_fmha_log_gfx942" } } @@ -364,20 +438,6 @@ def cmake_build(Map conf=[:]){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." } } - if (params.RUN_CK_TILE_GEMM_TESTS){ - try{ - archiveArtifacts "perf_tile_gemm_**.log" - if (arch_type == 1){ - stash includes: "perf_tile_gemm_**_gfx90a.log", name: "perf_tile_gemm_log_gfx90a" - } - else if (arch_type == 2){ - stash includes: "perf_tile_gemm_**_gfx942.log", name: "perf_tile_gemm_log_gfx942" - } - } - catch(Exception err){ - echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." 
- } - } } def buildHipClangJob(Map conf=[:]){ @@ -385,26 +445,24 @@ def buildHipClangJob(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm - - def image - if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){ - image = conf.get("docker_name", "") - echo "Using legacy docker: ${image}" - } - else{ - image = getDockerImageName() - echo "Using default docker: ${image}" - } def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts + if ( params.BUILD_INSTANCES_ONLY ){ + dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + } + else{ + dockerOpts = "--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + } if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + // the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with + // newer clang22 compilers and running with older hip runtime libraries + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 " } def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') def render_id = sh(returnStdout: true, script: 'getent group render | 
cut -d: -f3') @@ -412,7 +470,7 @@ def buildHipClangJob(Map conf=[:]){ echo "Docker flags: ${dockerOpts}" def variant = env.STAGE_NAME - + def image def retimage (retimage, image) = getDockerImage(conf) @@ -453,27 +511,18 @@ def Build_CK(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 env.DOCKER_BUILDKIT=1 checkout scm - - def image - if ( params.BUILD_LEGACY_OS && conf.get("docker_name", "") != "" ){ - image = conf.get("docker_name", "") - echo "Using legacy docker: ${image}" - } - else{ - image = getDockerImageName() - echo "Using default docker: ${image}" - } - def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " + // the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with + // newer clang22 compilers and running with older hip runtime libraries + dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 " } if(params.BUILD_LEGACY_OS){ dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' " @@ -484,6 +533,7 @@ echo "Docker flags: ${dockerOpts}" def variant = 
env.STAGE_NAME + def image def retimage gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { @@ -509,48 +559,35 @@ def Build_CK(Map conf=[:]){ timeout(time: 20, unit: 'HOURS') { //check whether to run performance tests on this node - def arch_type = 0 - sh 'rocminfo | tee rocminfo.log' - if ( runShell('grep -n "gfx90a" rocminfo.log') ){ - arch_type = 1 - } - else if ( runShell('grep -n "gfx942" rocminfo.log') ) { - arch_type = 2 - } - else if ( runShell('grep -n "gfx1030" rocminfo.log') ) { - arch_type = 3 - } - else if ( runShell('grep -n "gfx1101" rocminfo.log') ) { - arch_type = 4 - } - else if ( runShell('grep -n "gfx1201" rocminfo.log') ) { - arch_type = 5 - } - else if ( runShell('grep -n "gfx908" rocminfo.log') ) { - arch_type = 6 - } + def arch = check_arch() cmake_build(conf) - if ( !params.BUILD_LEGACY_OS && arch_type == 1 ){ + if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch == 1 ){ echo "Run inductor codegen tests" sh """ - pip install --verbose . - pytest python/test/test_gen_instances.py + python3 -m venv ${env.WORKSPACE} + . ${env.WORKSPACE}/bin/activate + python3 -m pip install pytest build setuptools setuptools_scm + python3 -m pip install . 
+ python3 -m pytest python/test/test_gen_instances.py """ } dir("build"){ - if (params.RUN_FULL_QA && arch_type == 1 ){ - // build deb packages for all gfx9 targets on gfx90a system and prepare to export - echo "Build ckProfiler package" - sh 'make -j package' - archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' - sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash includes: "ckprofiler_0.2.0_amd64.deb", name: "ckprofiler_0.2.0_amd64.deb" + if (params.RUN_FULL_QA && arch == 2 ){ + // build deb packages + echo "Build packages" + sh 'ninja package' + archiveArtifacts artifacts: 'composablekernel*.deb' + sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' + sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb' + sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb' + sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb' + stash includes: "composablekernel-**.deb", name: "packages" } } // run performance tests, stash the logs, results will be processed on the master node dir("script"){ if (params.RUN_PERFORMANCE_TESTS){ - if (params.RUN_FULL_QA && arch_type == 1){ + if (params.RUN_FULL_QA && arch == 1){ // run full tests on gfx90a echo "Run full performance tests" sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" @@ -569,7 +606,7 @@ def Build_CK(Map conf=[:]){ archiveArtifacts "perf_mixed_gemm.log" stash includes: "perf_**.log", name: "perf_log" } - else if ( arch_type == 1 ){ + else if ( arch == 1 ){ // run standard tests on gfx90a echo "Run performance tests" sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" @@ -580,40 +617,44 @@ def Build_CK(Map conf=[:]){ stash includes: "perf_**.log", name: "perf_log" } // disable performance tests on gfx1030 for now. 
- //else if ( arch_type == 3){ + //else if ( arch == 3){ // run basic tests on gfx1030 // echo "Run gemm performance tests" // sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx10" // archiveArtifacts "perf_onnx_gemm_gfx10.log" // stash includes: "perf_onnx_gemm_gfx10.log", name: "perf_log_gfx10" //} - else if ( arch_type == 4){ + else if ( arch == 4){ // run basic tests on gfx11 echo "Run gemm performance tests" sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx11" archiveArtifacts "perf_onnx_gemm_gfx11.log" stash includes: "perf_onnx_gemm_gfx11.log", name: "perf_log_gfx11" } - else if ( arch_type == 5 ){ + else if ( arch == 5 ){ // run basic tests on gfx12 echo "Run gemm performance tests" sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx12" archiveArtifacts "perf_onnx_gemm_gfx12.log" stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" } - else if ( arch_type == 6 ){ - // run standard tests on gfx908 + else if ( arch == 6 ){ + // run basic tests on gfx908 echo "Run performance tests" - sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" - archiveArtifacts "perf_gemm_gfx908.log" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" archiveArtifacts "perf_onnx_gemm_gfx908.log" - archiveArtifacts "perf_resnet50_N256_gfx908.log" - archiveArtifacts "perf_resnet50_N4_gfx908.log" - stash includes: "perf_**.log", name: "perf_log_gfx908" + stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908" + } + else if ( arch == 7 ){ + // run basic tests on gfx950 + echo "Run performance tests" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx950" + archiveArtifacts "perf_onnx_gemm_gfx950.log" + stash includes: "perf_onnx_gemm_gfx950.log", name: 
"perf_log_gfx950" } } } - if (params.hipTensor_test && arch_type == 1 ){ + if (params.hipTensor_test && arch == 1 ){ // build and test hipTensor on gfx90a node sh """#!/bin/bash rm -rf "${params.hipTensor_branch}".zip @@ -631,10 +672,6 @@ def Build_CK(Map conf=[:]){ """ } } - // set ownership of all files and folders to jenkins after all steps completed - dir("build"){ - sh "sudo chown -R jenkins:jenkins ../*" - } } } } @@ -660,7 +697,8 @@ def Build_CK_and_Reboot(Map conf=[:]){ def process_results(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm - def image = getDockerImageName() + //use older image that has user jenkins + def image = "rocm/composable_kernel:ck_ub22.04_rocm6.3" def prefixpath = "/opt/rocm" // Jenkins is complaining about the render group @@ -673,12 +711,17 @@ def process_results(Map conf=[:]){ def retimage gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { - try { - (retimage, image) = getDockerImage(conf) + try + { + echo "Pulling image: ${image}" + retimage = docker.image("${image}") + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { + retimage.pull() + } } - catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ - echo "The job was cancelled or aborted" - throw e + catch(Exception ex) + { + error "Unable to locate image: ${image}" } } @@ -695,28 +738,10 @@ def process_results(Map conf=[:]){ echo "could not locate the FMHA performance logs: ${err.getMessage()}." } } - if (params.RUN_CK_TILE_GEMM_TESTS){ - try{ - unstash "perf_tile_gemm_log_gfx942" - unstash "perf_tile_gemm_log_gfx90a" - } - catch(Exception err){ - echo "could not locate the GEMM performance logs: ${err.getMessage()}." 
- } - } - if (params.RUN_FULL_QA){ - // unstash perf files to master - unstash "ckprofiler_0.2.0_amd64.deb" - sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" - unstash "perf_log" - try{ - unstash "perf_log_gfx11" - unstash "perf_log_gfx12" - } - catch(Exception err){ - echo "could not locate the GEMM gfx11/gfx12 performance logs: ${err.getMessage()}." - } - sh "./process_qa_data.sh" + if (params.RUN_FULL_QA || params.BUILD_INSTANCES_ONLY){ + // unstash deb packages + unstash "packages" + sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" } else{ // unstash perf files to master @@ -744,14 +769,63 @@ def process_results(Map conf=[:]){ } } -//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? 
'''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.3;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 22 * * * % ROCMVERSION=6.3;BUILD_GFX908=true;BUILD_GFX12=false;RUN_PERFORMANCE_TESTS=false - 0 21 * * * % ROCMVERSION=6.3;hipTensor_test=true;RUN_CODEGEN_TESTS=true - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false - 0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : "" +def run_aiter_tests(Map conf=[:]){ + show_node_info() + env.HSA_ENABLE_SDMA=0 + checkout scm + //use the latest pytorch image + def image = "rocm/composable_kernel:ck_aiter" + def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins" + def variant = env.STAGE_NAME + def retimage + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" + + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + try + { + echo "Pulling image: ${image}" + retimage = docker.image("${image}") + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { + retimage.pull() + } + } + catch(Exception ex) + { + error "Unable to locate image: ${image}" + } + } + + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 45, unit: 
'MINUTES'){ + try{ + sh "rocminfo" + sh "python3 --version" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" + } + catch(e){ + echo "Throwing error exception while running AITER tests" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + echo "Finished running AITER tests" + } + } + } +} + +//launch develop branch daily jobs +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true + 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true + 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true + 0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : "" pipeline { agent none @@ -772,8 +846,8 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '6.3', - description: 'Specify which ROCM version to use: 6.3 (default).') + defaultValue: '6.4.1', + description: 'Specify which ROCM version to use: 6.4.1 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', @@ -812,24 +886,28 @@ pipeline { description: "Run the cppcheck static analysis (default: OFF)") booleanParam( name: 
"RUN_PERFORMANCE_TESTS", - defaultValue: true, - description: "Run the performance tests (default: ON)") + defaultValue: false, + description: "Run the performance tests (default: OFF)") booleanParam( name: "RUN_GROUPED_CONV_LARGE_CASES_TESTS", defaultValue: false, description: "Run the grouped conv large cases tests (default: OFF)") booleanParam( - name: "RUN_CODEGEN_TESTS", + name: "RUN_CONV_COMPREHENSIVE_DATASET", defaultValue: false, - description: "Run codegen tests (default: OFF)") + description: "Run comprehensive convolution dataset tests before important changes (default: OFF)") + booleanParam( + name: "RUN_CODEGEN_TESTS", + defaultValue: true, + description: "Run codegen tests (default: ON)") booleanParam( name: "RUN_CK_TILE_FMHA_TESTS", defaultValue: false, description: "Run the ck_tile FMHA tests (default: OFF)") booleanParam( - name: "RUN_CK_TILE_GEMM_TESTS", + name: "RUN_TILE_ENGINE_GEMM_TESTS", defaultValue: false, - description: "Run the ck_tile GEMM tests (default: OFF)") + description: "Run the tile_engine_gemm tests (default: OFF)") booleanParam( name: "BUILD_INSTANCES_ONLY", defaultValue: false, @@ -838,6 +916,26 @@ pipeline { name: "BUILD_GFX908", defaultValue: false, description: "Build CK and run tests on gfx908 (default: OFF)") + booleanParam( + name: "BUILD_GFX90A", + defaultValue: true, + description: "Build CK and run tests on gfx90a (default: ON)") + booleanParam( + name: "BUILD_GFX942", + defaultValue: false, + description: "Build CK and run tests on gfx942 (default: OFF)") + booleanParam( + name: "BUILD_GFX950", + defaultValue: false, + description: "Build CK and run tests on gfx950 (default: OFF)") + booleanParam( + name: "BUILD_GFX10", + defaultValue: true, + description: "Build CK and run tests on gfx10 (default: ON)") + booleanParam( + name: "BUILD_GFX11", + defaultValue: true, + description: "Build CK and run tests on gfx11 (default: ON)") booleanParam( name: "BUILD_GFX12", defaultValue: true, @@ -846,10 +944,34 @@ pipeline { 
name: "NINJA_BUILD_TRACE", defaultValue: false, description: "Generate a ninja build trace (default: OFF)") + booleanParam( + name: "NINJA_FTIME_TRACE", + defaultValue: false, + description: "Generate a detailed time trace (default: OFF)") booleanParam( name: "BUILD_LEGACY_OS", defaultValue: false, description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)") + booleanParam( + name: "RUN_INDUCTOR_TESTS", + defaultValue: true, + description: "Run inductor codegen tests (default: ON)") + booleanParam( + name: "RUN_ALL_UNIT_TESTS", + defaultValue: false, + description: "Run all unit tests (default: OFF)") + booleanParam( + name: "RUN_AITER_TESTS", + defaultValue: false, + description: "Run AITER tests with latest CK develop branch (default: OFF)") + string( + name: 'aiter_branch', + defaultValue: 'main', + description: 'Specify which branch of AITER to use (default: main)') + string( + name: 'ck_aiter_branch', + defaultValue: 'develop', + description: 'Specify which branch of CK to test with AITER (default: develop)') } environment{ dbuser = "${dbuser}" @@ -892,7 +1014,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ @@ -921,7 +1043,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'" } steps{ 
buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) @@ -930,6 +1052,24 @@ pipeline { } } } + stage("Run AITER Tests") + { + parallel + { + stage("Run AITER Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_AITER_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942")} + steps{ + run_aiter_tests() + cleanWs() + } + } + } + } stage("Run Grouped Conv Large Case Tests") { parallel @@ -944,8 +1084,35 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases_xdl && \ - ./bin/test_grouped_convnd_fwd_large_cases_xdl""" + make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } + stage("Run Comprehensive Convolution Dataset Tests") + { + parallel + { + stage("Run Comprehensive Dataset Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CONV_COMPREHENSIVE_DATASET.toBoolean() } + } + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ cd test_data && \ + ./generate_test_dataset.sh && \ + cd ../script && \ + ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 test_grouped_convnd_fwd_dataset_xdl && \ + ./bin/test_grouped_convnd_fwd_dataset_xdl""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -962,7 +1129,7 @@ pipeline { { when { beforeAgent true - expression { params.RUN_CODEGEN_TESTS.toBoolean() } + expression { 
params.RUN_CODEGEN_TESTS.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() } } agent{ label rocmnode("gfx90a")} environment{ @@ -1021,42 +1188,100 @@ pipeline { } } } - stage("Run CK_TILE_GEMM Tests") + stage("Run TILE_ENGINE_GEMM Tests") { parallel { - stage("Run CK_TILE_GEMM Tests on gfx90a") + stage("Run TILE_ENGINE_GEMM Tests on gfx90a") { when { beforeAgent true - expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() } + expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } } agent{ label rocmnode("gfx90a") } environment{ setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 tile_example_gemm_universal && \ - cd ../ && - example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx90a" \ + -D GEMM_DATATYPE="fp8;fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_MULTI_D_DATATYPE="fp16" \ + -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& \ + ninja -j64 benchmark_gemm_fp8_rcr && \ + ./bin/benchmark_gemm_fp8_rcr && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ./bin/benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp8_crr && \ + ./bin/benchmark_gemm_fp8_crr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ./bin/benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp8_ccr && \ + ./bin/benchmark_gemm_fp8_ccr && \ + ninja -j64 benchmark_gemm_fp16_ccr && \ + ./bin/benchmark_gemm_fp16_ccr && \ + ninja -j64 benchmark_gemm_fp8_rrr && \ + ./bin/benchmark_gemm_fp8_rrr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ./bin/benchmark_gemm_fp16_rrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ + ./bin/benchmark_gemm_multi_d_fp16_ccrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \ + ./bin/benchmark_gemm_multi_d_fp16_crrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rcrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() } } - stage("Run CK_TILE_GEMM Tests on gfx942") + stage("Run TILE_ENGINE_GEMM Tests on gfx942") { when { beforeAgent true - expression { params.RUN_CK_TILE_GEMM_TESTS.toBoolean() } + expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } } agent{ label rocmnode("gfx942") } environment{ setup_args = "NO_CK_BUILD" - execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ - make -j64 tile_example_gemm_universal && \ - cd ../ && - example/ck_tile/03_gemm/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx942" \ + -D GEMM_DATATYPE="fp8;fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -D GEMM_MULTI_D_DATATYPE="fp16" 
\ + -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && \ + ninja -j64 benchmark_gemm_fp8_rcr && \ + ./bin/benchmark_gemm_fp8_rcr && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ./bin/benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp8_crr && \ + ./bin/benchmark_gemm_fp8_crr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ./bin/benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp8_ccr && \ + ./bin/benchmark_gemm_fp8_ccr && \ + ninja -j64 benchmark_gemm_fp16_ccr && \ + ./bin/benchmark_gemm_fp16_ccr && \ + ninja -j64 benchmark_gemm_fp8_rrr && \ + ./bin/benchmark_gemm_fp8_rrr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ./bin/benchmark_gemm_fp16_rrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ + ./bin/benchmark_gemm_multi_d_fp16_ccrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \ + ./bin/benchmark_gemm_multi_d_fp16_crrr && \ + ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \ + ./bin/benchmark_gemm_multi_d_fp16_rcrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1108,21 +1333,22 @@ pipeline { cleanWs() } } - stage("Build CK for all gfx9 targets") + stage("Build CK and run Tests on gfx942") { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { (params.BUILD_GFX942.toBoolean() || params.RUN_FULL_QA.toBoolean()) && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx90a") } + agent{ label rocmnode("gfx942") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ - -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ + -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake 
-DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ + -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1130,23 +1356,26 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on gfx942") + stage("Build CK and run Tests on gfx950") { when { beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { params.BUILD_GFX950.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx942") } + agent{ label rocmnode("gfx950") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ + -DGPU_TARGETS="gfx950" \ + -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx942" \ - -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DGPU_TARGETS="gfx950" \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } @@ -1163,6 +1392,7 @@ pipeline { cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx908" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1174,7 +1404,7 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { params.BUILD_GFX90A.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx90a") } environment{ @@ -1183,6 +1413,7 @@ pipeline { cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx90a" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1190,22 +1421,27 @@ pipeline { cleanWs() } } - stage("Build CK instances for different targets") + stage("Build CK instances for all supported targets") { when { beforeAgent true expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx90a") } - environment{ - execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER="${build_compiler()}" \ - -D CMAKE_BUILD_TYPE=Release \ - -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ - -D CMAKE_CXX_FLAGS=" -O3 " .. 
&& ninja -j32 """ - } + agent{ label rocmnode("gfx942") } steps{ - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + script { + def execute_args = params.NINJA_FTIME_TRACE ? + """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D CMAKE_CXX_FLAGS=" -O3 -ftime-trace" .. && ninja -j64 """ : + """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ + + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + } cleanWs() } } @@ -1213,15 +1449,16 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { params.BUILD_GFX10.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx1030") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx10-3-generic" -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx1030" \ + -DGPU_TARGETS="gfx10-3-generic" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. 
&& make -j """ } steps{ @@ -1233,15 +1470,16 @@ pipeline { { when { beforeAgent true - expression { !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } + expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } agent{ label rocmnode("gfx1101") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx1101" \ + -DGPU_TARGETS="gfx11-generic" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ @@ -1257,11 +1495,12 @@ pipeline { } agent{ label rocmnode("gfx1201") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1201" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx12-generic" -DUSE_OPT_GFX12=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx1201" \ + -DGPU_TARGETS="gfx12-generic" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ diff --git a/README.md b/README.md index c316a0a322..459e17d9a3 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa 4. Build the entire CK library: ```bash - make -j + make -j"$(nproc)" ``` 5. 
Install CK: @@ -104,6 +104,7 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa ```bash make -j install ``` + **[See Note on -j](#notes)** ## Optional post-install steps @@ -146,7 +147,8 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html ``` -Note the `-j` option for building with multiple threads in parallel, which speeds up the build significantly. +### Notes +The `-j` option enables building with multiple threads in parallel, which speeds up the build significantly. However, `-j` launches an unlimited number of threads, which can cause the build to run out of memory and crash. On average, you should expect each thread to use ~2 GB of RAM. Depending on the number of CPU cores and the amount of RAM on your system, you may want to diff --git a/TERMINOLOGY.md b/TERMINOLOGY.md index e8833efb89..6dbe88640c 100644 --- a/TERMINOLOGY.md +++ b/TERMINOLOGY.md @@ -1,2 +1,348 @@ [Back to the main page](./README.md) -# Composable Kernel terminology \ No newline at end of file + +# Composable Kernel Terminology + +This document provides a technical reference for terminology used in the Composable Kernel library, organized by conceptual progression from hardware to machine learning operations. 
+ +--- + +## Glossary Index (Alphabetical) + +- [Add+Multiply](#addmultiply) +- [Bank Conflict](#bank-conflict) +- [Batched GEMM](#batched-gemm) +- [Benchmark](#benchmark) +- [Block Size](#block-size) +- [Block Tile](#block-tile) +- [Compute Unit (CU)](#compute-unit-cu) +- [Coordinate Transformation Primitives](#coordinate-transformation-primitives) +- [CUDA](#cuda) +- [Dense Tensor](#dense-tensor) +- [Descriptor](#descriptor) +- [Device](#device) +- [Elementwise](#elementwise) +- [Epilogue](#epilogue) +- [Fast Changing Dimension](#fast-changing-dimension) +- [GEMM](#gemm-general-matrix-multiply) +- [GEMV](#gemv) +- [Grouped GEMM](#grouped-gemm) +- [Global Memory](#global-memory) +- [Grid](#grid) +- [Host](#host) +- [HIP](#hip) +- [Inner Dimension](#inner-dimension) +- [Inner Product](#inner-product) +- [Input/Problem Shape](#inputproblem-shape) +- [Kernel](#kernel) +- [Launch Parameters](#launch-parameters) +- [Load Tile](#load-tile) +- [LDS Banks](#lds-banks) +- [Matrix Core](#matrix-core) +- [MFMA (Matrix Fused Multiply-Add)](#mfma-matrix-fused-multiply-add) +- [Occupancy](#occupancy) +- [Outer Dimension](#outer-dimension) +- [Outer Product](#outer-product) +- [Pinned Memory](#pinned-memory) +- [Pipeline](#pipeline) +- [Policy](#policy) +- [Problem](#problem) +- [Processing Units](#processing-units) +- [Reference Kernel](#reference-kernel) +- [Regression Test](#regression-test) +- [ROCm](#rocm) +- [Scalar General Purpose Register (SGPR)](#scalar-general-purpose-register-sgpr) +- [Shared Memory / LDS (Local Data Share)](#shared-memory--lds-local-data-share) +- [SIMT / SIMD](#simt--simd) +- [Smoke Test](#smoke-test) +- [Sparse Tensor](#sparse-tensor) +- [Split-K GEMM](#split-k-gemm) +- [Store Tile](#store-tile) +- [Thread / Work-item](#thread--work-item) +- [Thread Block / Work Group](#thread-block--work-group) +- [Vanilla GEMM](#vanilla-gemm) +- [Tile](#tile) +- [Tile Distribution](#tile-distribution) +- [Tile Partitioner](#tile-partitioner) +- [Tile Programming 
API](#tile-programming-api) +- [Tile Window](#tile-window) +- [User Customized Tile Pipeline](#user-customized-tile-pipeline) +- [User Customized Tile Pipeline Optimization](#user-customized-tile-pipeline-optimization) +- [Vector](#vector) +- [Vector General Purpose Register (VGPR)](#vector-general-purpose-register-vgpr) +- [Warp / Wavefront](#warp--wavefront) +- [Wave Tile](#wave-tile) +- [XDL Instructions](#xdl-instructions) + +--- + +## 1. Hardware and Memory + +### Processing Units +The GPU is composed of multiple hardware units ([compute units (CUs)](#compute-unit-cu) on AMD, [streaming multiprocessors (SMs)](#compute-unit-cu) on NVIDIA), each containing many cores that run threads in parallel. These units manage shared resources and coordinate execution at scale. + +### Matrix Core +Specialized GPU units that accelerate matrix operations for AI and deep learning tasks. Modern GPUs contain multiple matrix cores. + +### Compute Unit (CU) +AMD's parallel vector processor in a GPU with multiple ALUs. Each compute unit will run all the waves in a workgroup. _This is equivalent to NVIDIA's streaming multiprocessor (SM)_. + +### Matrix Fused Multiply-Add (MFMA) +AMD's matrix core instruction for efficient GEMM operations. CK optimizes kernel designs to maximize MFMA utilization and performance. + +### Registers +The fastest memory tier, registers are private to each thread/work-item and used for storing temporary variables during computation. AMD distinguishes between [vector (VGPR)](#vector-general-purpose-register-vgpr) and [scalar (SGPR)](#scalar-general-purpose-register-sgpr) registers, while NVIDIA uses a unified register file. + +### Vector General Purpose Register (VGPR) +Per-thread registers that store individual thread data within a wave. Each thread has its own set of VGPRs for private variables and calculations. + +### Scalar General Purpose Register (SGPR) +Wave-level registers shared by all threads in a wave. 
Used for constants, addresses, and control flow common across the entire wave. + +### Shared Memory / Local Data Share (LDS) +AMD's high-bandwidth, low-latency on-chip memory accessible to all threads within a work group. This is equivalent to NVIDIA's shared memory. It enables fast data sharing and synchronization, but is limited in capacity and must be managed to avoid [bank conflicts](#bank-conflict). + +### LDS Banks +Memory organization where consecutive addresses are distributed across multiple memory banks for parallel access. Prevents memory access conflicts ([bank conflicts](#bank-conflict)) and improves bandwidth. + +### Global Memory +The main device memory accessible by all threads, offering high capacity but higher latency than shared memory. + +### Pinned Memory +Host memory that is page-locked to accelerate transfers between CPU and GPU, reducing overhead for large data movements. + +### Dense Tensor +A tensor in which most elements are nonzero, typically stored in a contiguous block of memory. + +### Sparse Tensor +A tensor in which most elements are zero, allowing for memory and computation optimizations by storing only nonzero values and their indices. + +### Host +CPU and main memory system that manages GPU execution. Launches kernels, transfers data, and coordinates overall computation. + +### Device +GPU hardware that executes parallel kernels. Contains compute units, memory hierarchy, and specialized accelerators. + +--- + +## 2. GPU Programming Model + +### Thread / Work-item +AMD's work-item is the smallest unit of parallel execution, each running an independent instruction stream on a single data element. This is equivalent to NVIDIA's thread. Work-items/threads are grouped into [wavefronts (AMD)](#warp--wavefront) and [warps (NVIDIA)](#warp--wavefront) for efficient scheduling and resource sharing. + +### Warp / Wavefront +AMD's wavefront is a group of threads that run instructions in lockstep, forming the SIMD group. 
This is equivalent to NVIDIA's warp. + +### Thread Block / Work Group +AMD's work group is a collection of threads/work-items that can synchronize and share memory. This is equivalent to NVIDIA's thread block. Work groups/thread blocks are scheduled independently and mapped to hardware units for execution. + +### Grid +The complete collection of all work groups (thread blocks) that execute a kernel. A grid spans the entire computational domain and is organized in 1D, 2D, or 3D dimensions. Each work group within the grid operates independently and can be scheduled on different compute units, enabling massive parallel execution across the entire GPU. + +### Block Size +Number of work-items/threads in a compute unit (CU). Determines work group size and memory usage. + +### Single-Instruction, Multi-Thread (SIMT) / Single-Instruction, Multi-Data (SIMD) +SIMT (Single-Instruction, Multi-Thread) allows threads in a warp to diverge, while SIMD (Single-Instruction, Multi-Data) enforces strict lockstep execution within wavefronts. These models define how parallelism is expressed and managed on different architectures. + +### Occupancy +The ratio of active warps/wavefronts to the maximum number of warps/wavefronts supported by a hardware unit. Affects the ability to hide memory latency and maximize throughput. + +--- + +## 3. Kernel Structure + +### Kernel +A function executed on the GPU, typically written in [HIP](#hip) or [CUDA](#cuda), that performs parallel computations over input data. Kernels are launched with specific grid and block dimensions to map computation to hardware. In CK, kernels are composed from pipelines and require a pipeline, tile partitioner, and epilogue component. + +### Pipeline +A CK Pipeline orchestrates the sequence of operations for a kernel, including data loading, computation, and storage phases. 
It consists of two core components: a [Problem](#problem) component that defines what to compute, and a [Policy](#policy) component that specifies how to move data around. + +### Tile Partitioner +Defines the mapping between problem dimensions (M, N, K) and GPU hierarchy. It specifies workgroup-level tile sizes (kM, kN, kK) and determines grid dimensions by dividing the problem size by tile sizes. + +### Problem +Defines what to compute - input/output shapes, data types, and mathematical operations (e.g., GEMM, convolution). + +### Policy +Defines memory access patterns and hardware-specific optimizations. + +### User Customized Tile Pipeline +User-defined pipeline that combines custom problem and policy components for specialized computations. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points. + +### User Customized Tile Pipeline Optimization +Process of tuning tile sizes, memory access patterns, and hardware utilization for specific workloads. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points. + +### Tile Programming API +CK's high-level interface for defining tile-based computations with predefined hardware mapping for data load/store. + +### Coordinate Transformation Primitives +CK utilities for converting between different coordinate systems (logical, physical, memory layouts). + +### Reference Kernel +A baseline kernel implementation used to verify correctness and performance. CK has two reference kernel implementations: one for CPU and one for GPU. + +### Launch Parameters +Configuration values (e.g., grid size, block size) that determine how a kernel is mapped to hardware resources. Proper tuning of these parameters is essential for optimal performance. + +--- + +## 4. 
Memory Access and Data Layout + +### Memory Coalescing +An optimization where consecutive threads access consecutive memory addresses, allowing a single memory transaction to serve multiple threads. Proper coalescing is vital for achieving peak memory bandwidth. + +### Alignment +A memory management strategy for efficient memory access where data structures are stored at addresses that are multiples of a specific value. + +### Bank Conflict +Occurs when multiple threads in a warp/wavefront access different addresses mapping to the same shared memory bank, causing serialization and reduced bandwidth. + +### Padding +The addition of extra elements (often zeros) to tensor edges. This is used to control output size in convolution and pooling, or to align data for efficient memory access. + +### Permute/Transpose +Operations that rearrange the order of tensor axes, often required to match kernel input formats or optimize memory access patterns. + +### Host-Device Transfer +The process of moving data between CPU (host) and GPU (device) memory. Host-device transfers can be a performance bottleneck and are optimized using pinned memory and asynchronous operations. + +### Stride +The step size to move from one element to the next in a particular dimension of a tensor or matrix. In convolution and pooling, stride determines how far the kernel moves at each step. + +### Dilation +The spacing between kernel elements in convolution operations, allowing the receptive field to grow without increasing kernel size. + +### Im2Col/Col2Im +Data transformation techniques that convert image data to column format (im2col) for efficient convolution and back (col2im) to reconstruct the original layout. + +### Fast Changing Dimension +Innermost dimension that changes fastest in memory layout. + +### Outer Dimension +Slower-changing dimension in memory layout. + +### Inner Dimension +Faster-changing dimension in memory layout. + +--- + +## 5. 
Tile-Based Computing and Data Structures + +### Tile +A sub-region of a tensor or matrix processed by a block or thread. Tiles are used to improve memory locality and enable blocking strategies in kernels. Rectangular data blocks are the unit of computation and memory transfer in CK and the basis for tiled algorithms. + +### Block Tile +Memory tile processed by a work group (thread block). + +### Wave Tile +Sub-tile processed by a single wave within a work group. Represents the granularity of SIMD execution. + +### Tile Distribution +Hierarchical data mapping from work-items to data in memory. + +### Tile Window +Viewport into a larger tensor that defines the current tile's position and boundaries for computation. + +### Load Tile +Operation that transfers data from global memory/LDS to per-thread registers using optimized memory access patterns. + +### Store Tile +Operation that transfers data from per-thread registers to LDS/global memory using optimized memory access patterns. + +### Descriptor +Metadata structure that defines tile properties, memory layouts, and coordinate transformations for CK operations. + +### Input/Problem Shape +Dimensions and data types of input tensors that define the computational problem (e.g., M×K, K×N for GEMM). + +### Vector +Smallest data unit processed by individual threads. Typically 4-16 elements depending on data type and hardware. + +--- + +## 6. Kernel Operations and Optimization + +### Elementwise +Operations applied independently to each tensor element, such as addition or multiplication. These are highly parallelizable and benefit from efficient memory access. + +### Epilogue +The final stage of a kernel or operation, often applying activation functions, bias, or other post-processing steps. Epilogues are critical for integrating kernel outputs into larger computation graphs. 
+ +### Add+Multiply +A common fused operation in ML and linear algebra, where an elementwise addition is immediately followed by multiplication, often used for bias and scaling in neural network layers. + +--- + +## 7. Linear Algebra and ML Operations + +### General Matrix Multiply (GEMM) +Core matrix operation in linear algebra and deep learning. A GEMM is defined as C = αAB + βC for matrices A, B, and C. + +### "Vanilla" GEMM (Naive GEMM) Kernel +The **vanilla GEMM** is the simplest form of GEMM in CK. It: +- Takes input matrices **A** and **B** +- Multiplies them to produce output matrix **C** + +This is the **baseline** or **building block** GEMM that all other complex versions expand upon. + +### Grouped GEMM (GGEMMs) + +A kernel which calls multiple VGEMMs. Each call can have a different input shape. Each input shape problem first finds its corresponding kernel and then data is mapped to the work-group (blocks) of that kernel. + +### Batched GEMM +A kernel which calls VGEMMs with different "batches" of data. All batches have the same input shape. + +### Split-K GEMM +A parallelization strategy that partitions the reduction dimension (K) across multiple compute units, increasing parallelism for large matrix multiplications. + +### GEMV +The operation of multiplying a matrix by a vector, producing another vector. GEMV (General Matrix Vector Multiplication) is a core linear algebra primitive, widely used in neural networks and scientific computing. + +### Inner Product +Also known as the dot product, it computes the sum of elementwise products of two vectors, yielding a scalar. + +### Outer Product +The result of multiplying a column vector by a row vector, producing a matrix. Outer products are used in rank-1 updates and some ML algorithms. + +### Norm +A function that measures the magnitude of a vector or matrix, such as L2 (Euclidean) or L1 norm. Norms are used in regularization, normalization, and optimization. + +--- + +## 8. 
Testing, Build, and Infrastructure + +### Regression Test +Tests that are part of CK's ctest suite and explicitly take more than 30s to finish on gfx942. + +### Smoke Test +Tests that are part of CK's ctest suite and take less than or equal to 30 seconds to finish on gfx942. + +--- + +## 9. Low-Level Instructions and Optimizations + +### eXtensible Data Language (XDL) Instructions +eXtensible Data Language (XDL) instructions are a set of specialized, low-level instructions used to optimize data movement, memory access, and layout in high-performance computing, GPU programming, and deep learning tasks. + +--- + +## 10. Miscellaneous + +### HIP +AMD's Heterogeneous-Computing Interface for Portability, a C++ runtime API and programming language that enables developers to create portable applications for AMD and NVIDIA GPUs. HIP provides a familiar CUDA-like programming model while maintaining compatibility across different GPU architectures. + +### CUDA +NVIDIA's Compute Unified Device Architecture, a parallel computing platform and programming model for NVIDIA GPUs. CUDA provides a C++ extension for writing GPU kernels and managing GPU resources. + +### ROCm +AMD's Radeon Open Compute platform, an open-source software stack for GPU computing that includes [HIP](#hip), libraries, and tools for high-performance computing and machine learning workloads on AMD GPUs. + +--- + +## Scientific Context and References + +This terminology is grounded in parallel computing theory, numerical linear algebra, and computer architecture. For further reading, see: +- [Building Efficient GEMM Kernels with CK Tile](https://rocm.blogs.amd.com/software-tools-optimization/building-efficient-gemm-kernels-with-ck-tile-vendo/README.html) +- [CK Tile Flash](https://rocm.blogs.amd.com/software-tools-optimization/ck-tile-flash/README.html) + +This document assumes familiarity with parallel computing, linear algebra, and computer architecture principles. 
diff --git a/client_example/07_grouped_convnd_fwd/README.md b/client_example/07_grouped_convnd_fwd/README.md index 28a64ad733..9e96df222d 100644 --- a/client_example/07_grouped_convnd_fwd/README.md +++ b/client_example/07_grouped_convnd_fwd/README.md @@ -30,14 +30,14 @@ List of the device operations for grouped convolution forward in CK: Table of supported cases by instance factory with XDL instruction: -| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| -|-------|---|---|---| -|bf16 |2D, 3D|2D|1D, 2D, 3D| -|fp16 |2D, 3D|2D|1D, 2D, 3D| -|fp32 |2D, 3D|2D|1D, 2D, 3D| -|int8 |2D, 3D|2D|1D, 3D| -|fp8 |3D|✗|✗| -|bf8 |3D|✗|✗| +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---|---| +|bf16 |2D, 3D|2D|2D|1D, 2D, 3D| +|fp16 |2D, 3D|2D|2D|1D, 2D, 3D| +|fp32 |2D, 3D|2D|2D|1D, 2D, 3D| +|int8 |2D, 3D|2D|2D|1D, 3D| +|fp8 |3D|✗|✗|✗| +|bf8 |3D|✗|✗|✗| Table of supported cases by instance factory with WMMA instruction: diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp index 480abf23d2..13f1a3acc1 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp @@ -107,14 +107,14 @@ int execute_conv_fwd() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp index ae5f1b6f6e..f31ffe302a 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp @@ -130,14 +130,14 @@ int main() auto& 
op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp index 2309d757f0..a9918f6ab3 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp @@ -105,14 +105,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp index 93709a7901..baa2b02bce 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -109,14 +109,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp index a62a1d911b..ac7eb3cf41 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp +++ 
b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -111,14 +111,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/11_grouped_conv_bwd_weight/README.md b/client_example/11_grouped_conv_bwd_weight/README.md index 834fd62c8f..f1ba95e9cd 100644 --- a/client_example/11_grouped_conv_bwd_weight/README.md +++ b/client_example/11_grouped_conv_bwd_weight/README.md @@ -34,12 +34,12 @@ List of the device operations for grouped convolution backward weight in CK: Table of supported cases by instance factory with XDL instruction: -| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|GNHWC/GKYXC/GNHWK| -|-------|---|---|---| -|bf16|2D, 3D|2D, 3D|✗| -|bf16(fp32 for weight)|2D, 3D|✗|1D, 2D, 3D| -|fp16 |2D, 3D|2D, 3D|1D, 2D, 3D| -|fp32 |2D, 3D|2D, 3D|1D, 2D, 3D| +| |NHWGC/GKYXC/NHWGK|NGCHW/GKYXC/NGKHW|NGCHW/GKCYX/NGKHW|GNHWC/GKYXC/GNHWK| +|-------|---|---|---|---| +|bf16|2D, 3D|2D, 3D|2D, 3D|✗| +|bf16(fp32 for weight)|2D, 3D|✗|✗|1D, 2D, 3D| +|fp16 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D| +|fp32 |2D, 3D|2D, 3D|2D, 3D|1D, 2D, 3D| Table of supported cases by instance factory with WMMA instruction: diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index 69d7c8936c..37cafc190e 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -59,7 +59,7 @@ int main() SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size); std::array ab_input = {a_dev_buf.GetDeviceBuffer(), - b_dev_buf.GetDeviceBuffer()}; + b_dev_buf.GetDeviceBuffer()}; std::vector abStride = {Stride, 1}; std::array, 2> abStrides = 
{abStride, abStride}; diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index e2b1fbcb54..12aa31dec3 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -68,15 +68,15 @@ int main(int argc, char* argv[]) SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements); using DeviceOp = ck::tensor_operation::device::DeviceReduce; + AccDataType, + OutDataType, + Rank, + NumReduceDim, + ReduceAdd, + PassThrough, + UnaryDivide, + PropagateNan, + OutputIndex>; const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp index bb106e8d8e..e8e33a3de2 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp @@ -117,14 +117,14 @@ int execute_conv_bwd_data_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {in.GetDeviceBuffer()}, + {in.GetDeviceBuffer()}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {in_lengths}, - {in_strides}, + {in_lengths}, + {in_strides}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp index e53ecc6c99..d81b5fd03e 100644 --- 
a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp @@ -116,14 +116,14 @@ int execute_conv_bwd_data_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp index 32ab481319..2ec70b8b9b 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp @@ -121,14 +121,14 @@ int execute_conv_fwd_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {out.GetDeviceBuffer()}, + {out.GetDeviceBuffer()}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {out_lengths}, - {out_strides}, + {out_lengths}, + {out_strides}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp index c78cacf266..98f41dc7fb 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp @@ -222,13 +222,13 @@ bool run_grouped_conv_fwd_convscale_reduce( 
ck::tensor_operation::element_wise::Scale{scale_wei}, {}}; auto conv_ok = ConvolutionScale(in, + WeiDataType, + ConvOutDataType, + ConvElementOp, + InLayout, + WeiLayout, + OutLayout, + NumDimSpatial>(in, wei, conv_out, elementwise_op, @@ -717,15 +717,15 @@ bool TensorFullReduction(SimpleDeviceMem& tensor, { std::cout << "\nReduction of spatial dimensions:" << std::endl; using DeviceOp = ck::tensor_operation::device::DeviceReduce; // OutputIndex + OutDataType, + OutDataType, + NumDimSpatial, + NumDimSpatial, + ReduceOperation, + PassThrough, + AccElementwiseOperation, + true, // PropagateNan + false>; // OutputIndex const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp index 11e69f5bb2..11f24b39c7 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp @@ -120,14 +120,14 @@ int execute_conv_fwd_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc index 3f6f7b0773..4cf3a4cf82 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc +++ 
b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc @@ -129,8 +129,8 @@ int execute_conv_fwd_scaleadd_ab() in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp index ceccc5eb8f..f7f893fda2 100644 --- a/client_example/25_wrapper/wrapper_img2col.cpp +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -132,9 +132,9 @@ void PerformImageToColumnPad0(const ck::index_t G, ck::wrapper::size<0>(tile_shape)); const auto kernel = DeviceImageToColumnPad0; + decltype(output_tensor_global), + decltype(tile_shape), + decltype(thread_layout)>; const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, kernel, dim3(grid_size_x, grid_size_y, 1), diff --git a/client_example/32_gemm_mx/CMakeLists.txt b/client_example/32_gemm_mx/CMakeLists.txt new file mode 100644 index 0000000000..558986bf5a --- /dev/null +++ b/client_example/32_gemm_mx/CMakeLists.txt @@ -0,0 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx950") + add_executable(client_gemm_mx_fp8 gemm_mx_fp8.cpp) + target_link_libraries(client_gemm_mx_fp8 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/32_gemm_mx/gemm_mx_fp8.cpp b/client_example/32_gemm_mx/gemm_mx_fp8.cpp new file mode 100644 index 0000000000..6e14bf2a5f --- /dev/null +++ b/client_example/32_gemm_mx/gemm_mx_fp8.cpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using CDataType = ck::half_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; +template +inline constexpr bool is_same_v = ck::is_same::value; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AScaleLayout = Row; +using BScaleLayout = Col; + +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % 
(XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + mem_size_ = mem_size; + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; + std::size_t mem_size_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t KBatch = 1; + + /* Require by mx type*/ + constexpr ck::index_t ScaleBlockSize = 32; // scaling block size + + if(argc == 1) + { + // use default case + } + else if(argc == 7) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + K = std::stoi(argv[3]); + + StrideA = std::stoi(argv[4]); + StrideB = std::stoi(argv[5]); + StrideC = std::stoi(argv[6]); + } + else + { + printf("arg1 to 6: M, N, K, StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + /* Scale stride Calculation */ + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { 
+ if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + return static_cast(col); + else + return static_cast(row); + } + else + return static_cast(stride); + }; + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + auto Scale_Padded_M = (M + ScaleBlockSize - 1) / ScaleBlockSize * ScaleBlockSize; + auto Scale_Stride_AM = + f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{})); + SimpleDeviceMem a_scale_device_buf( + sizeof(XDataType) * + f_matrix_space_size(Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + SimpleDeviceMem b_scale_device_buf( + sizeof(XDataType) * + f_matrix_space_size(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + using DeviceOp = + ck::tensor_operation::device::DeviceGemmMX; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr 
= op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(ADataType) * M * K / ck::packed_size_v + + sizeof(BDataType) * K * N / ck::packed_size_v + + sizeof(CDataType) * M * N + + sizeof(XDataType) * M * K / ScaleBlockSize + + sizeof(XDataType) * N * K / ScaleBlockSize; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + 
static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 9e2012bf8a..f27e557cc3 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.15) project(ck_app) -add_compile_options(-std=c++17) +add_compile_options(-std=c++20) if (DTYPES) add_definitions(-DDTYPES) @@ -32,7 +32,7 @@ if (DTYPES) add_definitions(-DCK_ENABLE_BF16) set(CK_ENABLE_BF16 "ON") endif() - message("DTYPES macro set to ${DTYPES}") + message(DEBUG "DTYPES macro set to ${DTYPES}") else() add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) set(CK_ENABLE_INT8 "ON") diff --git a/client_example/README.md b/client_example/README.md index d9f793434d..34c6733d05 100644 --- a/client_example/README.md +++ b/client_example/README.md @@ -14,8 +14,10 @@ cd client_example/build cmake \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH="/opt/rocm;${PATH_TO_CK_INSTALL_DIRECTORY}" \ +-D GPU_TARGETS="gfx908;gfx90a" \ .. ``` +You must set the `GPU_TARGETS` macro to specify the GPU target architecture(s). 
### Build client example ```bash diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index fb2b38d688..0c81f8df98 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,8 @@ else() -Wunreachable-code -Wunused -Wno-reserved-identifier - -Werror + # Werror set outside by BUILD_DEV + # -Werror -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt @@ -108,7 +109,7 @@ else() endif() list(APPEND CMAKE_COMPILER_WARNINGS -Wno-missing-field-initializers - -Wno-deprecated-declarations + -Wno-error=deprecated-declarations ) endif() add_definitions(${CMAKE_COMPILER_WARNINGS}) diff --git a/cmake/ShardInstantiation.cmake b/cmake/ShardInstantiation.cmake new file mode 100644 index 0000000000..47a5d0c48c --- /dev/null +++ b/cmake/ShardInstantiation.cmake @@ -0,0 +1,116 @@ +# Function to generate templated instantiation functions and caller function. + +# In order to reduce build times, we split the instantiation of template functions into multiple files. +# Developers can use ck::util::generate_sharded_instantiations to generate the instantiation functions, +# which can be placed in the TEMPLATE_FILE (typically a .in file). + +# This CMake function generates the instantiation functions and a caller function that calls all the instantiation +# functions. The ck::util::generate_sharded_instantiations function allows us to generate an arbitrary number of +# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard, +# and generates a caller function that calls all the instantiation functions. + +# The explicit instantiation pattern requires the use of `extern template` to avoid implicit instantiation +# of the template functions in the caller function, and that code is automatically generated by this function. + +# In addition to the user-supplied template, this CMake function uses two generic templates: +# +# 1. 
`instantiate_shard.in`: This is the template for the instantiation functions. +# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions. + +# This function takes the following arguments: +# +# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCES_NAME}`). +# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions. +# - NUM_SHARDS: The number of shards to generate. +# - OUTPUT_DIR: The build directory where the generated source files will be placed. +# - SRC_LIST: The list of source files to which the generated source files will be added. + + +function(generate_sharded_instantiations) + cmake_parse_arguments( + GEN_SHARDED + # No boolean arguments + "" + # Single-value arguments + "INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST" + # No multi-value arguments. + "" + ${ARGN} + ) + if (NOT GEN_SHARDED_INSTANCES_NAME) + message(FATAL_ERROR "INSTANCES_NAME is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_TEMPLATE_FILE) + message(FATAL_ERROR "TEMPLATE_FILE is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_NUM_SHARDS) + message(FATAL_ERROR "NUM_SHARDS is required for generate_sharded_instantiations") + endif() + if(NOT GEN_SHARDED_OUTPUT_DIR) + message(FATAL_ERROR "OUTPUT_DIR is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_SRC_LIST) + message(FATAL_ERROR "SRC_LIST is required for generate_sharded_instantiations") + endif() + + file(MAKE_DIRECTORY ${GEN_SHARDED_OUTPUT_DIR}) + + + set(GENERATED_SOURCE_FILES "") + set(EXTERN_TEMPLATE_STATEMENTS "") + set(CALL_STATEMENTS "") + message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}") + + set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}") + + # Generate the inc file with the template function definitions. 
+ # This include file will hold the template function definitions and a using alias for all the shard + # instantiation functions. + configure_file( + "${GEN_SHARDED_TEMPLATE_FILE}" + "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc" + @ONLY + ) + + # Generate the sharded instantiation functions. + # This is where the build parallelization happens. + # Each of these source files will contain a single instantiation function for a shard, + # which will be called sequentially by the caller function. + set(INC_DIR "${GEN_SHARDED_INC_DIR}") + math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1") + foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID}) + set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}") + set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp") + set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in") + configure_file( + "${SHARD_FUNCTION_TEMPLATE}" + "${SHARD_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}") + set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>") + list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)") + list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)") + endforeach() + + # Join the include statements, the extern template declarations, and the call statements each + # into a single string for variable substitution in the caller function. + string(REPLACE ";" ";\n" INCLUDE_STATEMENTS "${INCLUDE_STATEMENTS}") + string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}") + string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}") + + # Generate the caller function. 
+ set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp") + set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in") + configure_file( + "${FUNCTION_TEMPLATE}" + "${CALLER_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}") + + # Add the generated source files to the list of source files. + # This allows the generated source files to be included in the build. + list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES}) + set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE) +endfunction() \ No newline at end of file diff --git a/cmake/call_shard.in b/cmake/call_shard.in new file mode 100644 index 0000000000..daba79b055 --- /dev/null +++ b/cmake/call_shard.in @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { + +@EXTERN_TEMPLATE_STATEMENTS@; + +void add_@INSTANCES@( + @INSTANCES@& instances) { +@CALL_STATEMENTS@; +} + +} // namespace ck::tensor_operation::device::instance diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 0915f53411..6587f4c4be 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -68,3 +68,6 @@ endif() target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS}) target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) +target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0) +target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0) + diff --git a/cmake/instantiate_shard.in b/cmake/instantiate_shard.in new file mode 100644 index 0000000000..dbc0af17a9 --- /dev/null +++ b/cmake/instantiate_shard.in @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { +template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>( + @INSTANCES@& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 9e7c360f54..2b2e6e2949 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -19,12 +19,10 @@ list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS ${CK_ROOT}/include/ck/*.hpp) -# printouts fot debug purposes -# message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") -# message(STATUS "RELATIVE: ${CK_ROOT}/include") + add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) -add_compile_options(-std=c++17) +add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library @@ -48,6 +46,7 @@ rocm_install_targets( INCLUDE include ) rocm_export_targets( + TARGETS ck_host ck_headers EXPORT ck_host_targets NAMESPACE composable_kernel:: ) diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp index 89c1884d2e..81b312ec95 100644 --- a/codegen/include/ck/host/stringutils.hpp +++ b/codegen/include/ck/host/stringutils.hpp @@ -91,8 +91,9 @@ inline auto Transform(const Range& r, F f) -> std::vector -inline auto Transform(const Range1& r1, const Range2& r2, F f) - -> std::vector +inline auto Transform(const Range1& r1, + const Range2& r2, + F f) -> std::vector { std::vector result; assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end())); diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp index 36c9a13b4c..a2f322c50f 100644 --- a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ 
-142,12 +142,11 @@ std::vector Operation_Conv_Fwd_Xdl_Cshuffle::Cr x.A = TensorDesc{prob.ADataType, prob.ALayout}; x.B = TensorDesc{prob.BDataType, prob.BLayout}; x.E = TensorDesc{prob.EDataType, prob.ELayout}; - x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { - return TensorDesc{dt, lo}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; + x.Ds = Transform( + prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { return TensorDesc{dt, lo}; }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; x.update_prologue(prologue); x.update_epilogue(epilogue); result.push_back(x); diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp index 13035df355..98e78fc148 100644 --- a/codegen/test/batched_gemm_softmax_gemm.cpp +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -55,12 +55,12 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}, - {"o", std::to_string(prob.O)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}, + {"o", std::to_string(prob.O)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index adc8e1ff02..dd908e8b58 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -60,11 +60,11 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << 
std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index 2e7ceb5648..b8a60cd633 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -8,5 +8,5 @@ target_link_libraries(ck_rtc PUBLIC -lstdc++fs) option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) if(USE_HIPRTC_FOR_CODEGEN_TESTS) target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS) - message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") + message(STATUS "CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") endif() diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index 2f3b26cc43..f4983debd9 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -16,7 +16,7 @@ struct tmp_dir void execute(const std::string& cmd) const; - tmp_dir(tmp_dir const&) = delete; + tmp_dir(tmp_dir const&) = delete; tmp_dir& operator=(tmp_dir const&) = delete; ~tmp_dir(); diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 262e6bae46..fac92ded7d 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -94,7 +94,7 @@ kernel clang_compile_kernel(const std::vector& srcs, compile_options o 
assert(not srcs.empty()); tmp_dir td{"compile"}; options.flags += " -I. -O3"; - options.flags += " -std=c++17"; + options.flags += " -std=c++20"; options.flags += " --offload-arch=" + get_device_name(); std::string out; @@ -278,7 +278,7 @@ std::vector> compile_hip_src_with_hiprtc(const std::vector& srcs, compile_options options) { options.flags += " -I. -O3"; - options.flags += " -std=c++17"; + options.flags += " -std=c++20"; options.flags += " -DCK_CODE_GEN_RTC"; options.flags += " --offload-arch=" + get_device_name(); auto cos = compile_hip_src_with_hiprtc(srcs, options); diff --git a/docs/reference/Supported_Primitives_Guide.rst b/docs/conceptual/Composable-Kernel-math.rst similarity index 85% rename from docs/reference/Supported_Primitives_Guide.rst rename to docs/conceptual/Composable-Kernel-math.rst index e24acf5656..1c21fd8a11 100644 --- a/docs/reference/Supported_Primitives_Guide.rst +++ b/docs/conceptual/Composable-Kernel-math.rst @@ -1,18 +1,15 @@ .. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation + :description: Composable Kernel mathematical basis + :keywords: composable kernel, CK, ROCm, API, mathematics, algorithm .. _supported-primitives: ******************************************************************** -Supported Primitives Guide +Composable Kernel mathematical basis ******************************************************************** -This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK. +This is an introduction to the math which underpins the algorithms implemented in Composable Kernel. 
------------- -Softmax ------------- For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` you can decompose the softmax of concatenated :math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as, diff --git a/docs/conceptual/Composable-Kernel-structure.rst b/docs/conceptual/Composable-Kernel-structure.rst new file mode 100644 index 0000000000..43c3603b95 --- /dev/null +++ b/docs/conceptual/Composable-Kernel-structure.rst @@ -0,0 +1,29 @@ +.. meta:: + :description: Composable Kernel structure + :keywords: composable kernel, CK, ROCm, API, structure + +.. _what-is-ck: + +******************************************************************** +Composable Kernel structure +******************************************************************** + +The Composable Kernel library uses a tile-based programming model and tensor coordinate transformation to achieve performance portability and code maintainability. Tensor coordinate transformation is a complexity reduction technique for complex machine learning operators. + + +.. image:: ../data/ck_component.png + :alt: CK Components + + +The Composable Kernel library consists of four layers: + +* a templated tile operator layer +* a templated kernel and invoker layer +* an instantiated kernel and invoker layer +* a client API layer. + +A wrapper component is included to simplify tensor transform operations. + +.. image:: ../data/ck_layer.png + :alt: CK Layers + \ No newline at end of file diff --git a/docs/conceptual/what-is-ck.rst b/docs/conceptual/what-is-ck.rst deleted file mode 100644 index 36785fc6ca..0000000000 --- a/docs/conceptual/what-is-ck.rst +++ /dev/null @@ -1,41 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. 
_what-is-ck: - -******************************************************************** -What is the Composable Kernel library -******************************************************************** - - -Methodology -=========== - -The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. - -CK utilizes two concepts to achieve performance portability and code maintainability: - -* A tile-based programming model -* Algorithm complexity reduction for complex ML operators using an innovative technique called - "Tensor Coordinate Transformation". - -.. image:: ../data/ck_component.png - :alt: CK Components - - -Code Structure -============== - -The CK library is structured into 4 layers: - -* "Templated Tile Operators" layer -* "Templated Kernel and Invoker" layer -* "Instantiated Kernel and Invoker" layer -* "Client API" layer - -It also includes a simple wrapper component used to perform tensor transform operations more easily and with fewer lines of code. - -.. image:: ../data/ck_layer.png - :alt: CK Layers - \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index e8617a09ef..fe8a1c1d79 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,6 +28,7 @@ external_toc_path = "./sphinx/_toc.yml" docs_core = ROCmDocs(left_nav_title) docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") +docs_core.enable_api_reference() docs_core.setup() external_projects_current_project = "composable_kernel" diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index fac9e138e1..4c8019f8d3 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -1,4 +1,4 @@ -# Doxyfile 1.8.10 +# Doxyfile 1.9.7 # This file describes the settings to be used by the documentation system # doxygen (www.doxygen.org) for a project. 
@@ -12,16 +12,26 @@ # For lists, items can also be appended using: # TAG += value [value, ...] # Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables or CMake type +# replacement variables: +# doxygen -x_noenv [configFile] #--------------------------------------------------------------------------- # Project related configuration options #--------------------------------------------------------------------------- -# This tag specifies the encoding used for all characters in the config file -# that follow. The default is UTF-8 which is also the encoding used for all text -# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv -# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv -# for the list of possible encodings. +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. # The default value is: UTF-8. DOXYFILE_ENCODING = UTF-8 @@ -32,26 +42,26 @@ DOXYFILE_ENCODING = UTF-8 # title of most generated pages and in a few other places. # The default value is: My Project. -PROJECT_NAME = "ck" +PROJECT_NAME = "Composable Kernel" # The PROJECT_NUMBER tag can be used to enter a project or revision number. This # could be handy for archiving the generated documentation or if some version # control system is used. 
-PROJECT_NUMBER = v3.0.1.0 +PROJECT_NUMBER = # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a # quick idea about the purpose of the project. Keep the description short. -PROJECT_BRIEF = "prototype interfaces compatible with ROCm platform and HiP" +PROJECT_BRIEF = "Prototype interfaces compatible with ROCm platform and HiP" # With the PROJECT_LOGO tag one can specify a logo or an icon that is included # in the documentation. The maximum height of the logo should not exceed 55 # pixels and the maximum width should not exceed 200 pixels. Doxygen will copy # the logo to the output directory. -PROJECT_LOGO = +PROJECT_LOGO = # The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path # into which the generated documentation will be written. If a relative path is @@ -60,16 +70,28 @@ PROJECT_LOGO = OUTPUT_DIRECTORY = . -# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- -# directories (in 2 levels) under the output directory of each output format and -# will distribute the generated files over these directories. Enabling this +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 +# sub-directories (in 2 levels) under the output directory of each output format +# and will distribute the generated files over these directories. Enabling this # option can be useful when feeding doxygen a huge amount of source files, where # putting all generated files in the same directory would otherwise causes -# performance problems for the file system. +# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to +# control the number of sub-directories. # The default value is: NO. CREATE_SUBDIRS = NO +# Controls the number of sub-directories that will be created when +# CREATE_SUBDIRS tag is set to YES. 
Level 0 represents 16 directories, and every +# level increment doubles the number of directories, resulting in 4096 +# directories at level 8 which is the default and also the maximum value. The +# sub-directories are organized in 2 levels, the first level always has a fixed +# number of 16 directories. +# Minimum value: 0, maximum value: 8, default value: 8. +# This tag requires that the tag CREATE_SUBDIRS is set to YES. + +CREATE_SUBDIRS_LEVEL = 8 + # If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII # characters to appear in the names of generated files. If set to NO, non-ASCII # characters will be escaped, for example _xE3_x81_x84 will be used for Unicode @@ -81,14 +103,14 @@ ALLOW_UNICODE_NAMES = NO # The OUTPUT_LANGUAGE tag is used to specify the language in which all # documentation generated by doxygen is written. Doxygen will use this # information to generate all constant output in the proper language. -# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, -# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), -# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, -# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), -# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, -# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, -# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, -# Ukrainian and Vietnamese. 
+# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, +# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English +# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, +# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with +# English messages), Korean, Korean-en (Korean with English messages), Latvian, +# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, +# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, +# Swedish, Turkish, Ukrainian and Vietnamese. # The default value is: English. OUTPUT_LANGUAGE = English @@ -162,7 +184,8 @@ FULL_PATH_NAMES = YES # will be relative from the directory where doxygen is started. # This tag requires that the tag FULL_PATH_NAMES is set to YES. -STRIP_FROM_PATH = +#STRIP_FROM_PATH = +STRIP_FROM_PATH = /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/latest/ # The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the # path mentioned in the documentation of a class, which tells the reader which @@ -171,7 +194,8 @@ STRIP_FROM_PATH = # specify the list of include paths that are normally passed to the compiler # using the -I flag. -STRIP_FROM_INC_PATH = +STRIP_FROM_INC_PATH = + # If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but # less readable) file names. This can be useful is your file systems doesn't @@ -189,6 +213,16 @@ SHORT_NAMES = NO JAVADOC_AUTOBRIEF = NO +# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line +# such as +# /*************** +# as being the beginning of a Javadoc-style comment "banner". If set to NO, the +# Javadoc-style will behave just like regular comments and it will not be +# interpreted by doxygen. +# The default value is: NO. 
+ +JAVADOC_BANNER = NO + # If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first # line (until the first dot) of a Qt-style comment as the brief description. If # set to NO, the Qt-style will behave just like regular Qt-style comments (thus @@ -209,6 +243,14 @@ QT_AUTOBRIEF = NO MULTILINE_CPP_IS_BRIEF = NO +# By default Python docstrings are displayed as preformatted text and doxygen's +# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the +# doxygen's special commands can be used and the contents of the docstring +# documentation blocks is shown as doxygen documentation. +# The default value is: YES. + +PYTHON_DOCSTRING = YES + # If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the # documentation from any documented member that it re-implements. # The default value is: YES. @@ -232,20 +274,19 @@ TAB_SIZE = 4 # the documentation. An alias has the form: # name=value # For example adding -# "sideeffect=@par Side Effects:\n" +# "sideeffect=@par Side Effects:^^" # will allow you to put the command \sideeffect (or @sideeffect) in the # documentation, which will result in a user-defined paragraph with heading -# "Side Effects:". You can put \n's in the value part of an alias to insert -# newlines. +# "Side Effects:". Note that you cannot put \n's in the value part of an alias +# to insert newlines (in the resulting output). You can put ^^ in the value part +# of an alias to insert a newline as if a physical newline was in the original +# file. When you need a literal { or } or , in the value part of an alias you +# have to escape them by means of a backslash (\), this can lead to conflicts +# with the commands \{ and \} for these it is advised to use the version @{ and +# @} or use a double escape (\\{ and \\}) ALIASES = -# This tag can be used to specify a number of word-keyword mappings (TCL only). -# A mapping has the form "name=value". 
For example adding "class=itcl::class" -# will allow you to use the command class in the itcl::class meaning. - -TCL_SUBST = - # Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources # only. Doxygen will then generate output that is more tailored for C. For # instance, some of the names that are used will be different. The list of all @@ -274,28 +315,40 @@ OPTIMIZE_FOR_FORTRAN = NO OPTIMIZE_OUTPUT_VHDL = NO +# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice +# sources only. Doxygen will then generate output that is more tailored for that +# language. For instance, namespaces will be presented as modules, types will be +# separated into more groups, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_SLICE = NO + # Doxygen selects the parser to use depending on the extension of the files it # parses. With this tag you can assign which parser to use for a given # extension. Doxygen has a built-in mapping, but you can override or extend it # using this tag. The format is ext=language, where ext is a file extension, and -# language is one of the parsers supported by doxygen: IDL, Java, Javascript, -# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran: -# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran: -# Fortran. In the later case the parser tries to guess whether the code is fixed -# or free formatted code, this is the default for Fortran type files), VHDL. For -# instance to make doxygen treat .inc files as Fortran files (default is PHP), -# and .f files as C (default is Fortran), use: inc=Fortran f=C. +# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, +# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice, +# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: +# FortranFree, unknown formatted Fortran: Fortran. 
In the later case the parser +# tries to guess whether the code is fixed or free formatted code, this is the +# default for Fortran type files). For instance to make doxygen treat .inc files +# as Fortran files (default is PHP), and .f files as C (default is Fortran), +# use: inc=Fortran f=C. # # Note: For files without extension you can use no_extension as a placeholder. # # Note that for custom extensions you also need to set FILE_PATTERNS otherwise -# the files are not read by doxygen. +# the files are not read by doxygen. When specifying no_extension you should add +# * to the FILE_PATTERNS. +# +# Note see also the list of default file extension mappings. EXTENSION_MAPPING = # If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments # according to the Markdown format, which allows for more readable -# documentation. See http://daringfireball.net/projects/markdown/ for details. +# documentation. See https://daringfireball.net/projects/markdown/ for details. # The output of markdown processing is further processed by doxygen, so you can # mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in # case of backward compatibilities issues. @@ -303,6 +356,26 @@ EXTENSION_MAPPING = MARKDOWN_SUPPORT = YES +# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up +# to that level are automatically included in the table of contents, even if +# they do not have an id attribute. +# Note: This feature currently applies only to Markdown headings. +# Minimum value: 0, maximum value: 99, default value: 5. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +TOC_INCLUDE_HEADINGS = 5 + +# The MARKDOWN_ID_STYLE tag can be used to specify the algorithm used to +# generate identifiers for the Markdown headings. Note: Every identifier is +# unique. +# Possible values are: DOXYGEN Use a fixed 'autotoc_md' string followed by a +# sequence number starting at 0. 
and GITHUB Use the lower case version of title +# with any whitespace replaced by '-' and punctations characters removed.. +# The default value is: DOXYGEN. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +MARKDOWN_ID_STYLE = DOXYGEN + # When enabled doxygen tries to link words that correspond to documented # classes, or namespaces to their corresponding documentation. Such a link can # be prevented in individual cases by putting a % sign in front of the word or @@ -328,7 +401,7 @@ BUILTIN_STL_SUPPORT = YES CPP_CLI_SUPPORT = NO # Set the SIP_SUPPORT tag to YES if your project consists of sip (see: -# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen +# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen # will parse them like normal C++ but will assume all classes use public instead # of private inheritance when no explicit protection keyword is present. # The default value is: NO. @@ -414,6 +487,27 @@ TYPEDEF_HIDES_STRUCT = YES LOOKUP_CACHE_SIZE = 0 +# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use +# during processing. When set to 0 doxygen will based this on the number of +# cores available in the system. You can set it explicitly to a value larger +# than 0 to get more control over the balance between CPU load and processing +# speed. At this moment only the input processing can be done using multiple +# threads. Since this is still an experimental feature the default is set to 1, +# which effectively disables parallel processing. Please report any issues you +# encounter. Generating dot graphs in parallel is controlled by the +# DOT_NUM_THREADS setting. +# Minimum value: 0, maximum value: 32, default value: 1. + +NUM_PROC_THREADS = 1 + +# If the TIMESTAMP tag is set different from NO then each generated page will +# contain the date or date and time when the page was generated. Setting this to +# NO can help when comparing the output of multiple runs. 
+# Possible values are: YES, NO, DATETIME and DATE. +# The default value is: NO. + +TIMESTAMP = YES + #--------------------------------------------------------------------------- # Build related configuration options #--------------------------------------------------------------------------- @@ -434,6 +528,12 @@ EXTRACT_ALL = YES EXTRACT_PRIVATE = NO +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIV_VIRTUAL = NO + # If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal # scope will be included in the documentation. # The default value is: NO. @@ -471,6 +571,13 @@ EXTRACT_LOCAL_METHODS = NO EXTRACT_ANON_NSPACES = NO +# If this flag is set to YES, the name of an unnamed parameter in a declaration +# will be determined by the corresponding definition. By default unnamed +# parameters remain unnamed in the output. +# The default value is: YES. + +RESOLVE_UNNAMED_PARAMS = YES + # If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all # undocumented members inside documented classes or files. If set to NO these # members will be included in the various overviews, but no documentation @@ -482,14 +589,15 @@ HIDE_UNDOC_MEMBERS = NO # If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all # undocumented classes that are normally visible in the class hierarchy. If set # to NO, these classes will be included in the various overviews. This option -# has no effect if EXTRACT_ALL is enabled. +# will also hide undocumented C++ concepts if enabled. This option has no effect +# if EXTRACT_ALL is enabled. # The default value is: NO. HIDE_UNDOC_CLASSES = NO # If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend -# (class|struct|union) declarations. If set to NO, these declarations will be -# included in the documentation. +# declarations. 
If set to NO, these declarations will be included in the +# documentation. # The default value is: NO. HIDE_FRIEND_COMPOUNDS = NO @@ -508,12 +616,20 @@ HIDE_IN_BODY_DOCS = NO INTERNAL_DOCS = NO -# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file -# names in lower-case letters. If set to YES, upper-case letters are also -# allowed. This is useful if you have classes or files whose names only differ -# in case and if your file system supports case sensitive file names. Windows -# and Mac users are advised to set this option to NO. -# The default value is: system dependent. +# With the correct setting of option CASE_SENSE_NAMES doxygen will better be +# able to match the capabilities of the underlying filesystem. In case the +# filesystem is case sensitive (i.e. it supports files in the same directory +# whose names only differ in casing), the option must be set to YES to properly +# deal with such files in case they appear in the input. For filesystems that +# are not case sensitive the option should be set to NO to properly deal with +# output files written for symbols that only differ in casing, such as for two +# classes, one named CLASS and the other named Class, and to also support +# references to files without having to specify the exact matching casing. On +# Windows (including Cygwin) and MacOS, users should typically set this option +# to NO, whereas on Linux or other Unix flavors it should typically be set to +# YES. +# Possible values are: SYSTEM, NO and YES. +# The default value is: SYSTEM. CASE_SENSE_NAMES = NO @@ -531,6 +647,12 @@ HIDE_SCOPE_NAMES = NO HIDE_COMPOUND_REFERENCE= NO +# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class +# will show which file needs to be included to use the class. +# The default value is: YES. 
+ +SHOW_HEADERFILE = YES + # If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of # the files that are included by a file in the documentation of that file. # The default value is: YES. @@ -688,7 +810,8 @@ FILE_VERSION_FILTER = # output files in an output format independent way. To create the layout file # that represents doxygen's defaults, run doxygen with the -l option. You can # optionally specify a file name after the option, if omitted DoxygenLayout.xml -# will be used as the name of the layout file. +# will be used as the name of the layout file. See also section "Changing the +# layout of pages" for information. # # Note that if you run doxygen from a directory containing a file called # DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE @@ -699,7 +822,7 @@ LAYOUT_FILE = # The CITE_BIB_FILES tag can be used to specify one or more bib files containing # the reference definitions. This must be a list of .bib files. The .bib # extension is automatically appended if omitted. This requires the bibtex tool -# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info. +# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info. # For LaTeX the style of the bibliography can be controlled using # LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the # search path. See also \cite for info how to create references. @@ -734,34 +857,81 @@ WARNINGS = YES WARN_IF_UNDOCUMENTED = YES # If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for -# potential errors in the documentation, such as not documenting some parameters -# in a documented function, or documenting parameters that don't exist or using -# markup commands wrongly. +# potential errors in the documentation, such as documenting some parameters in +# a documented function twice, or documenting parameters that don't exist or +# using markup commands wrongly. 
# The default value is: YES. WARN_IF_DOC_ERROR = YES +# If WARN_IF_INCOMPLETE_DOC is set to YES, doxygen will warn about incomplete +# function parameter documentation. If set to NO, doxygen will accept that some +# parameters have no documentation without warning. +# The default value is: YES. + +WARN_IF_INCOMPLETE_DOC = YES + # This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that # are documented, but have no documentation for their parameters or return -# value. If set to NO, doxygen will only warn about wrong or incomplete -# parameter documentation, but not about the absence of documentation. +# value. If set to NO, doxygen will only warn about wrong parameter +# documentation, but not about the absence of documentation. If EXTRACT_ALL is +# set to YES then this flag will automatically be disabled. See also +# WARN_IF_INCOMPLETE_DOC # The default value is: NO. WARN_NO_PARAMDOC = NO +# If WARN_IF_UNDOC_ENUM_VAL option is set to YES, doxygen will warn about +# undocumented enumeration values. If set to NO, doxygen will accept +# undocumented enumeration values. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: NO. + +WARN_IF_UNDOC_ENUM_VAL = NO + +# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when +# a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS +# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but +# at the end of the doxygen process doxygen will return with a non-zero status. 
+# If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS_PRINT then doxygen behaves +# like FAIL_ON_WARNINGS but in case no WARN_LOGFILE is defined doxygen will not +# write the warning messages in between other messages but write them at the end +# of a run, in case a WARN_LOGFILE is defined the warning messages will be +# besides being in the defined file also be shown at the end of a run, unless +# the WARN_LOGFILE is defined as - i.e. standard output (stdout) in that case +# the behavior will remain as with the setting FAIL_ON_WARNINGS. +# Possible values are: NO, YES, FAIL_ON_WARNINGS and FAIL_ON_WARNINGS_PRINT. +# The default value is: NO. + +WARN_AS_ERROR = NO + # The WARN_FORMAT tag determines the format of the warning messages that doxygen # can produce. The string should contain the $file, $line, and $text tags, which # will be replaced by the file and line number from which the warning originated # and the warning text. Optionally the format may contain $version, which will # be replaced by the version of the file (if it could be obtained via # FILE_VERSION_FILTER) +# See also: WARN_LINE_FORMAT # The default value is: $file:$line: $text. WARN_FORMAT = "$file:$line: $text" +# In the $text part of the WARN_FORMAT command it is possible that a reference +# to a more specific place is given. To make it easier to jump to this place +# (outside of doxygen) the user can define a custom "cut" / "paste" string. +# Example: +# WARN_LINE_FORMAT = "'vi $file +$line'" +# See also: WARN_FORMAT +# The default value is: at line $line of file $file. + +WARN_LINE_FORMAT = "at line $line of file $file" + # The WARN_LOGFILE tag can be used to specify a file to which warning and error # messages should be written. If left blank the output is written to standard -# error (stderr). +# error (stderr). In case the file specified cannot be opened for writing the +# warning and error messages are written to standard error. 
When as file - is +# specified the warning and error messages are written to standard output +# (stdout). WARN_LOGFILE = @@ -775,22 +945,31 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../../include/ck/tensor_operation/gpu/grid \ - ../../include/ck/tensor_operation/gpu/block \ - ../../include/ck/tensor_operation/gpu/thread \ +INPUT = ../../include \ + ../../include/ck/ \ ../../library/include/ck/library/utility \ - ../../include/ck/wrapper - + ../../include/ck_tile # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses # libiconv (or the iconv built into libc) for the transcoding. See the libiconv -# documentation (see: http://www.gnu.org/software/libiconv) for the list of -# possible encodings. +# documentation (see: +# https://www.gnu.org/software/libiconv/) for the list of possible encodings. +# See also: INPUT_FILE_ENCODING # The default value is: UTF-8. INPUT_ENCODING = UTF-8 +# This tag can be used to specify the character encoding of the source files +# that doxygen parses The INPUT_FILE_ENCODING tag can be used to specify +# character encoding on a per file pattern basis. Doxygen will compare the file +# name with each pattern and apply the encoding instead of the default +# INPUT_ENCODING) if there is a match. The character encodings are a list of the +# form: pattern=encoding (like *.php=ISO-8859-1). See cfg_input_encoding +# "INPUT_ENCODING" for further information on supported encodings. + +INPUT_FILE_ENCODING = + # If the value of the INPUT tag contains directories, you can use the # FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and # *.h) to filter out the source-files in the directories. @@ -799,11 +978,15 @@ INPUT_ENCODING = UTF-8 # need to set EXTENSION_MAPPING for the extension otherwise the files are not # read by doxygen. 
# +# Note the list of default checked file patterns might differ from the list of +# default file extension mappings. +# # If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, # *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, -# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, -# *.m, *.markdown, *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, -# *.vhdl, *.ucf, *.qsf, *.as and *.js. +# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, +# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C +# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, +# *.vhdl, *.ucf, *.qsf and *.ice. FILE_PATTERNS = *.c \ *.cc \ @@ -824,6 +1007,7 @@ FILE_PATTERNS = *.c \ *.hxx \ *.hpp \ *.h++ \ + *.l \ *.cs \ *.d \ *.php \ @@ -837,13 +1021,19 @@ FILE_PATTERNS = *.c \ *.mm \ *.dox \ *.py \ - *.tcl \ + *.pyw \ + *.f90 \ + *.f95 \ + *.f03 \ + *.f08 \ + *.f18 \ + *.f \ + *.for \ *.vhd \ *.vhdl \ *.ucf \ *.qsf \ - *.as \ - *.js + *.ice # The RECURSIVE tag can be used to specify whether or not subdirectories should # be searched for input files as well. @@ -880,10 +1070,7 @@ EXCLUDE_PATTERNS = # (namespaces, classes, functions, etc.) that should be excluded from the # output. The symbol name can be a fully qualified name, a word, or if the # wildcard * is used, a substring. Examples: ANamespace, AClass, -# AClass::ANamespace, ANamespace::*Test -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories use the pattern */test/* +# ANamespace::AClass, ANamespace::*Test EXCLUDE_SYMBOLS = @@ -927,6 +1114,15 @@ IMAGE_PATH = # Note that the filter must not add or remove lines; it is applied before the # code is scanned, but not when the output code is generated. If lines are added # or removed, the anchors will not be placed correctly. 
+# +# Note that doxygen will use the data processed and written to standard output +# for further processing, therefore nothing else, like debug statements or used +# commands (so in case of a Windows batch file always use @echo OFF), should be +# written to standard output. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. INPUT_FILTER = @@ -936,6 +1132,10 @@ INPUT_FILTER = # (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how # filters are used. If the FILTER_PATTERNS tag is empty or if none of the # patterns match the file name, INPUT_FILTER is applied. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. FILTER_PATTERNS = @@ -959,7 +1159,17 @@ FILTER_SOURCE_PATTERNS = # (index.html). This can be useful if you have a project on for instance GitHub # and want to reuse the introduction page also for the doxygen output. -USE_MDFILE_AS_MAINPAGE = ../README.md + +USE_MDFILE_AS_MAINPAGE = + +# The Fortran standard specifies that for fixed formatted Fortran code all +# characters from position 72 are to be considered as comment. A common +# extension is to allow longer lines before the automatic comment starts. The +# setting FORTRAN_COMMENT_AFTER will also make it possible that longer lines can +# be processed before the automatic comment starts. +# Minimum value: 7, maximum value: 10000, default value: 72. 
+ +FORTRAN_COMMENT_AFTER = 72 #--------------------------------------------------------------------------- # Configuration options related to source browsing @@ -988,7 +1198,7 @@ INLINE_SOURCES = NO STRIP_CODE_COMMENTS = YES # If the REFERENCED_BY_RELATION tag is set to YES then for each documented -# function all documented functions referencing it will be listed. +# entity all documented functions referencing it will be listed. # The default value is: NO. REFERENCED_BY_RELATION = NO @@ -1020,12 +1230,12 @@ SOURCE_TOOLTIPS = YES # If the USE_HTAGS tag is set to YES then the references to source code will # point to the HTML generated by the htags(1) tool instead of doxygen built-in # source browser. The htags tool is part of GNU's global source tagging system -# (see http://www.gnu.org/software/global/global.html). You will need version +# (see https://www.gnu.org/software/global/global.html). You will need version # 4.8.6 or higher. # # To use it do the following: # - Install the latest version of global -# - Enable SOURCE_BROWSER and USE_HTAGS in the config file +# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file # - Make sure the INPUT points to the root of the source tree # - Run doxygen as normal # @@ -1047,25 +1257,6 @@ USE_HTAGS = NO VERBATIM_HEADERS = YES -# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the -# clang parser (see: http://clang.llvm.org/) for more accurate parsing at the -# cost of reduced performance. This can be particularly helpful with template -# rich C++ code for which doxygen's built-in parser lacks the necessary type -# information. -# Note: The availability of this option depends on whether or not doxygen was -# compiled with the --with-libclang option. -# The default value is: NO. - -CLANG_ASSISTED_PARSING = NO - -# If clang assisted parsing is enabled you can provide the compiler with command -# line options that you would normally use when invoking the compiler. 
Note that -# the include paths will already be set by doxygen for the files and directories -# specified with INPUT and INCLUDE_PATH. -# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. - -CLANG_OPTIONS = - #--------------------------------------------------------------------------- # Configuration options related to the alphabetical class index #--------------------------------------------------------------------------- @@ -1077,17 +1268,11 @@ CLANG_OPTIONS = ALPHABETICAL_INDEX = YES -# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in -# which the alphabetical index list will be split. -# Minimum value: 1, maximum value: 20, default value: 5. -# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. - -COLS_IN_ALPHA_INDEX = 5 - -# In case all classes in a project start with a common prefix, all classes will -# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag -# can be used to specify a prefix (or a list of prefixes) that should be ignored -# while generating the index headers. +# The IGNORE_PREFIX tag can be used to specify a prefix (or a list of prefixes) +# that should be ignored while generating the index headers. The IGNORE_PREFIX +# tag works for classes, function and member names. The entity will be placed in +# the alphabetical list under the first letter of the entity name that remains +# after removing the prefix. # This tag requires that the tag ALPHABETICAL_INDEX is set to YES. IGNORE_PREFIX = @@ -1134,7 +1319,7 @@ HTML_FILE_EXTENSION = .html # of the possible markers and block names see the documentation. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_HEADER = +HTML_HEADER = ../_doxygen/header.html # The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each # generated HTML page. If the tag is left blank doxygen will generate a standard @@ -1144,7 +1329,7 @@ HTML_HEADER = # that doxygen normally uses. 
# This tag requires that the tag GENERATE_HTML is set to YES. -HTML_FOOTER = +HTML_FOOTER = ../_doxygen/footer.html # The HTML_STYLESHEET tag can be used to specify a user-defined cascading style # sheet that is used by each HTML page. It can be used to fine-tune the look of @@ -1156,7 +1341,7 @@ HTML_FOOTER = # obsolete. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_STYLESHEET = +HTML_STYLESHEET = ../_doxygen/stylesheet.css # The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined # cascading style sheets that are included after the standard style sheets @@ -1166,10 +1351,15 @@ HTML_STYLESHEET = # Doxygen will copy the style sheet files to the output directory. # Note: The order of the extra style sheet files is of importance (e.g. the last # style sheet in the list overrules the setting of the previous ones in the -# list). For an example see the documentation. +# list). +# Note: Since the styling of scrollbars can currently not be overruled in +# Webkit/Chromium, the styling will be left out of the default doxygen.css if +# one or more extra stylesheets have been specified. So if scrollbar +# customization is desired it has to be added explicitly. For an example see the +# documentation. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_EXTRA_STYLESHEET = +HTML_EXTRA_STYLESHEET = ../_doxygen/extra_stylesheet.css # The HTML_EXTRA_FILES tag can be used to specify one or more extra images or # other source files which should be copied to the HTML output directory. Note @@ -1179,21 +1369,34 @@ HTML_EXTRA_STYLESHEET = # files will be copied as-is; there are no commands or markers available. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_EXTRA_FILES = +HTML_EXTRA_FILES = ../_doxygen/extra_stylesheet.css + +# The HTML_COLORSTYLE tag can be used to specify if the generated HTML output +# should be rendered with a dark or light theme. 
+# Possible values are: LIGHT always generate light mode output, DARK always +# generate dark mode output, AUTO_LIGHT automatically set the mode according to +# the user preference, use light mode if no preference is set (the default), +# AUTO_DARK automatically set the mode according to the user preference, use +# dark mode if no preference is set and TOGGLE allow to user to switch between +# light and dark mode via a button. +# The default value is: AUTO_LIGHT. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE = LIGHT # The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen # will adjust the colors in the style sheet and background images according to -# this color. Hue is specified as an angle on a colorwheel, see -# http://en.wikipedia.org/wiki/Hue for more information. For instance the value +# this color. Hue is specified as an angle on a color-wheel, see +# https://en.wikipedia.org/wiki/Hue for more information. For instance the value # 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 # purple, and 360 is red again. # Minimum value: 0, maximum value: 359, default value: 220. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_COLORSTYLE_HUE = 220 +HTML_COLORSTYLE_HUE = 240 # The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors -# in the HTML output. For a value of 0 the output will use grayscales only. A +# in the HTML output. For a value of 0 the output will use gray-scales only. A # value of 255 will produce the most vivid colors. # Minimum value: 0, maximum value: 255, default value: 100. # This tag requires that the tag GENERATE_HTML is set to YES. @@ -1211,14 +1414,16 @@ HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. 
Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. +# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML +# documentation will contain a main index with vertical navigation menus that +# are dynamically created via JavaScript. If disabled, the navigation index will +# consists of multiple levels of tabs that are statically embedded in every HTML +# page. Disable this option to support browsers that do not have JavaScript, +# like the Qt help browser. +# The default value is: YES. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_TIMESTAMP = NO +HTML_DYNAMIC_MENUS = YES # If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML # documentation will contain sections that can be hidden and shown after the @@ -1243,13 +1448,14 @@ HTML_INDEX_NUM_ENTRIES = 100 # If the GENERATE_DOCSET tag is set to YES, additional index files will be # generated that can be used as input for Apple's Xcode 3 integrated development -# environment (see: http://developer.apple.com/tools/xcode/), introduced with -# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a -# Makefile in the HTML output directory. Running make will produce the docset in -# that directory and running make install will install the docset in +# environment (see: +# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To +# create a documentation set, doxygen will generate a Makefile in the HTML +# output directory. Running make will produce the docset in that directory and +# running make install will install the docset in # ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at -# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html -# for more information. +# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy +# genXcode/_index.html for more information. 
# The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. @@ -1263,6 +1469,13 @@ GENERATE_DOCSET = NO DOCSET_FEEDNAME = "Doxygen generated docs" +# This tag determines the URL of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDURL = + # This tag specifies a string that should uniquely identify the documentation # set bundle. This should be a reverse domain-name style string, e.g. # com.mycompany.MyDocSet. Doxygen will append .docset to the name. @@ -1288,8 +1501,12 @@ DOCSET_PUBLISHER_NAME = Publisher # If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three # additional HTML index files: index.hhp, index.hhc, and index.hhk. The # index.hhp is a project file that can be read by Microsoft's HTML Help Workshop -# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on -# Windows. +# on Windows. In the beginning of 2021 Microsoft took the original page, with +# a.o. the download links, offline the HTML help workshop was already many years +# in maintenance mode). You can download the HTML help workshop from the web +# archives at Installation executable (see: +# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo +# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe). # # The HTML Help Workshop contains a compiler that can convert all HTML output # generated by doxygen into a single compiled HTML file (.chm). Compiled HTML @@ -1319,7 +1536,7 @@ CHM_FILE = HHC_LOCATION = # The GENERATE_CHI flag controls if a separate .chi index file is generated -# (YES) or that it should be included in the master .chm file (NO). +# (YES) or that it should be included in the main .chm file (NO). # The default value is: NO. 
# This tag requires that the tag GENERATE_HTMLHELP is set to YES. @@ -1346,6 +1563,16 @@ BINARY_TOC = NO TOC_EXPAND = NO +# The SITEMAP_URL tag is used to specify the full URL of the place where the +# generated documentation will be placed on the server by the user during the +# deployment of the documentation. The generated sitemap is called sitemap.xml +# and placed on the directory specified by HTML_OUTPUT. In case no SITEMAP_URL +# is specified no sitemap is generated. For information about the sitemap +# protocol see https://www.sitemaps.org +# This tag requires that the tag GENERATE_HTML is set to YES. + +SITEMAP_URL = + # If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and # QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that # can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help @@ -1364,7 +1591,8 @@ QCH_FILE = # The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help # Project output. For more information please see Qt Help Project / Namespace -# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace). +# (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace). # The default value is: org.doxygen.Project. # This tag requires that the tag GENERATE_QHP is set to YES. @@ -1372,8 +1600,8 @@ QHP_NAMESPACE = org.doxygen.Project # The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt # Help Project output. For more information please see Qt Help Project / Virtual -# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual- -# folders). +# Folders (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders). # The default value is: doc. # This tag requires that the tag GENERATE_QHP is set to YES. @@ -1381,30 +1609,30 @@ QHP_VIRTUAL_FOLDER = doc # If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom # filter to add. 
For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_CUST_FILTER_NAME = # The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the # custom filter to add. For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_CUST_FILTER_ATTRS = # The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this # project's filter section matches. Qt Help Project / Filter Attributes (see: -# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes). +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_SECT_FILTER_ATTRS = -# The QHG_LOCATION tag can be used to specify the location of Qt's -# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the -# generated .qhp file. +# The QHG_LOCATION tag can be used to specify the location (absolute path +# including file name) of Qt's qhelpgenerator. If non-empty doxygen will try to +# run qhelpgenerator on the generated .qhp file. # This tag requires that the tag GENERATE_QHP is set to YES. QHG_LOCATION = @@ -1447,16 +1675,28 @@ DISABLE_INDEX = NO # to work a browser that supports JavaScript, DHTML, CSS and frames is required # (i.e. any modern browser). Windows users are probably better off using the # HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can -# further fine-tune the look of the index. 
As an example, the default style -# sheet generated by doxygen has an example that shows how to put an image at -# the root of the tree instead of the PROJECT_NAME. Since the tree basically has -# the same information as the tab index, you could consider setting -# DISABLE_INDEX to YES when enabling this option. +# further fine tune the look of the index (see "Fine-tuning the output"). As an +# example, the default style sheet generated by doxygen has an example that +# shows how to put an image at the root of the tree instead of the PROJECT_NAME. +# Since the tree basically has the same information as the tab index, you could +# consider setting DISABLE_INDEX to YES when enabling this option. # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. GENERATE_TREEVIEW = NO +# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the +# FULL_SIDEBAR option determines if the side bar is limited to only the treeview +# area (value NO) or if it should extend to the full height of the window (value +# YES). Setting this to YES gives a layout similar to +# https://docs.readthedocs.io with more room for contents, but less room for the +# project logo, title, and description. If either GENERATE_TREEVIEW or +# DISABLE_INDEX is set to NO, this option has no effect. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FULL_SIDEBAR = NO + # The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that # doxygen will group on one line in the generated HTML documentation. # @@ -1481,6 +1721,24 @@ TREEVIEW_WIDTH = 250 EXT_LINKS_IN_WINDOW = NO +# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email +# addresses. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. 
+ +OBFUSCATE_EMAILS = YES + +# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg +# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see +# https://inkscape.org) to generate formulas as SVG images instead of PNGs for +# the HTML output. These images will generally look nicer at scaled resolutions. +# Possible values are: png (the default) and svg (looks nicer but requires the +# pdf2svg or inkscape tool). +# The default value is: png. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FORMULA_FORMAT = png + # Use this tag to change the font size of LaTeX formulas included as images in # the HTML documentation. When you change the font size after a successful # doxygen run you need to manually remove any form_*.png images from the HTML @@ -1490,19 +1748,14 @@ EXT_LINKS_IN_WINDOW = NO FORMULA_FONTSIZE = 10 -# Use the FORMULA_TRANPARENT tag to determine whether or not the images -# generated for formulas are transparent PNGs. Transparent PNGs are not -# supported properly for IE 6.0, but are supported on all modern browsers. -# -# Note that when changing this option you need to delete any form_*.png files in -# the HTML output directory before the changes have effect. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. +# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands +# to create new LaTeX commands to be used in formulas as building blocks. See +# the section "Including formulas" for details. -FORMULA_TRANSPARENT = YES +FORMULA_MACROFILE = # Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see -# http://www.mathjax.org) which uses client side Javascript for the rendering +# https://www.mathjax.org) which uses client side JavaScript for the rendering # instead of using pre-rendered bitmaps. Use this if you do not have LaTeX # installed or if you want to formulas look prettier in the HTML output. 
When # enabled you may also need to install MathJax separately and configure the path @@ -1512,11 +1765,29 @@ FORMULA_TRANSPARENT = YES USE_MATHJAX = YES +# With MATHJAX_VERSION it is possible to specify the MathJax version to be used. +# Note that the different versions of MathJax have different requirements with +# regards to the different settings, so it is possible that also other MathJax +# settings have to be changed when switching between the different MathJax +# versions. +# Possible values are: MathJax_2 and MathJax_3. +# The default value is: MathJax_2. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_VERSION = MathJax_2 + # When MathJax is enabled you can set the default output format to be used for -# the MathJax output. See the MathJax site (see: -# http://docs.mathjax.org/en/latest/output.html) for more details. +# the MathJax output. For more details about the output format see MathJax +# version 2 (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3 +# (see: +# http://docs.mathjax.org/en/latest/web/components/output.html). # Possible values are: HTML-CSS (which is slower, but has the best -# compatibility), NativeMML (i.e. MathML) and SVG. +# compatibility. This is the name for Mathjax version 2, for MathJax version 3 +# this will be translated into chtml), NativeMML (i.e. MathML. Only supported +# for MathJax 2. For MathJax version 3 chtml will be used instead.), chtml (This +# is the name for Mathjax version 3, for MathJax version 2 this will be +# translated into HTML-CSS) and SVG. # The default value is: HTML-CSS. # This tag requires that the tag USE_MATHJAX is set to YES. @@ -1529,22 +1800,29 @@ MATHJAX_FORMAT = HTML-CSS # MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax # Content Delivery Network so you can quickly see the result without installing # MathJax.
However, it is strongly recommended to install a local copy of -# MathJax from http://www.mathjax.org before deployment. -# The default value is: http://cdn.mathjax.org/mathjax/latest. +# MathJax from https://www.mathjax.org before deployment. The default value is: +# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2 +# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3 # This tag requires that the tag USE_MATHJAX is set to YES. -MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest +MATHJAX_RELPATH = # The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax # extension names that should be enabled during MathJax rendering. For example +# for MathJax version 2 (see +# https://docs.mathjax.org/en/v2.7-latest/tex.html#tex-and-latex-extensions): # MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols +# For example for MathJax version 3 (see +# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html): +# MATHJAX_EXTENSIONS = ams # This tag requires that the tag USE_MATHJAX is set to YES. MATHJAX_EXTENSIONS = # The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces # of code that will be used on startup of the MathJax code. See the MathJax site -# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an +# (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an # example see the documentation. # This tag requires that the tag USE_MATHJAX is set to YES. @@ -1569,10 +1847,10 @@ MATHJAX_CODEFILE = # The default value is: YES. # This tag requires that the tag GENERATE_HTML is set to YES. -SEARCHENGINE = YES +SEARCHENGINE = NO # When the SERVER_BASED_SEARCH tag is enabled the search engine will be -# implemented using a web server instead of a web client using Javascript. There +# implemented using a web server instead of a web client using JavaScript. 
There # are two flavors of web server based searching depending on the EXTERNAL_SEARCH # setting. When disabled, doxygen will generate a PHP script for searching and # an index file used by the script. When EXTERNAL_SEARCH is enabled the indexing @@ -1591,7 +1869,8 @@ SERVER_BASED_SEARCH = NO # # Doxygen ships with an example indexer (doxyindexer) and search engine # (doxysearch.cgi) which are based on the open source search engine library -# Xapian (see: http://xapian.org/). +# Xapian (see: +# https://xapian.org/). # # See the section "External Indexing and Searching" for details. # The default value is: NO. @@ -1604,8 +1883,9 @@ EXTERNAL_SEARCH = NO # # Doxygen ships with an example indexer (doxyindexer) and search engine # (doxysearch.cgi) which are based on the open source search engine library -# Xapian (see: http://xapian.org/). See the section "External Indexing and -# Searching" for details. +# Xapian (see: +# https://xapian.org/). See the section "External Indexing and Searching" for +# details. # This tag requires that the tag SEARCHENGINE is set to YES. SEARCHENGINE_URL = @@ -1656,21 +1936,35 @@ LATEX_OUTPUT = latex # The LATEX_CMD_NAME tag can be used to specify the LaTeX command name to be # invoked. # -# Note that when enabling USE_PDFLATEX this option is only used for generating -# bitmaps for formulas in the HTML output, but not in the Makefile that is -# written to the output directory. -# The default file is: latex. +# Note that when not enabling USE_PDFLATEX the default is latex when enabling +# USE_PDFLATEX the default is pdflatex and when in the later case latex is +# chosen this is overwritten by pdflatex. For specific output languages the +# default can have been set differently, this depends on the implementation of +# the output language. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_CMD_NAME = latex # The MAKEINDEX_CMD_NAME tag can be used to specify the command name to generate # index for LaTeX. 
+# Note: This tag is used in the Makefile / make.bat. +# See also: LATEX_MAKEINDEX_CMD for the part in the generated output file +# (.tex). # The default file is: makeindex. # This tag requires that the tag GENERATE_LATEX is set to YES. MAKEINDEX_CMD_NAME = makeindex +# The LATEX_MAKEINDEX_CMD tag can be used to specify the command name to +# generate index for LaTeX. In case there is no backslash (\) as first character +# it will be automatically added in the LaTeX code. +# Note: This tag is used in the generated output file (.tex). +# See also: MAKEINDEX_CMD_NAME for the part in the Makefile / make.bat. +# The default value is: makeindex. +# This tag requires that the tag GENERATE_LATEX is set to YES. + +LATEX_MAKEINDEX_CMD = makeindex + # If the COMPACT_LATEX tag is set to YES, doxygen generates more compact LaTeX # documents. This may be useful for small projects and may help to save some # trees in general. @@ -1700,29 +1994,31 @@ PAPER_TYPE = a4 EXTRA_PACKAGES = -# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the -# generated LaTeX document. The header should contain everything until the first -# chapter. If it is left blank doxygen will generate a standard header. See -# section "Doxygen usage" for information on how to let doxygen write the -# default header to a separate file. +# The LATEX_HEADER tag can be used to specify a user-defined LaTeX header for +# the generated LaTeX document. The header should contain everything until the +# first chapter. If it is left blank doxygen will generate a standard header. It +# is highly recommended to start with a default header using +# doxygen -w latex new_header.tex new_footer.tex new_stylesheet.sty +# and then modify the file new_header.tex. See also section "Doxygen usage" for +# information on how to generate the default header that doxygen normally uses. # -# Note: Only use a user-defined header if you know what you are doing! 
The -# following commands have a special meaning inside the header: $title, -# $datetime, $date, $doxygenversion, $projectname, $projectnumber, -# $projectbrief, $projectlogo. Doxygen will replace $title with the empty -# string, for the replacement values of the other commands the user is referred -# to HTML_HEADER. +# Note: Only use a user-defined header if you know what you are doing! +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. The following +# commands have a special meaning inside the header (and footer): For a +# description of the possible markers and block names see the documentation. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_HEADER = -# The LATEX_FOOTER tag can be used to specify a personal LaTeX footer for the -# generated LaTeX document. The footer should contain everything after the last -# chapter. If it is left blank doxygen will generate a standard footer. See +# The LATEX_FOOTER tag can be used to specify a user-defined LaTeX footer for +# the generated LaTeX document. The footer should contain everything after the +# last chapter. If it is left blank doxygen will generate a standard footer. See # LATEX_HEADER for more information on how to generate a default footer and what -# special commands can be used inside the footer. -# -# Note: Only use a user-defined footer if you know what you are doing! +# special commands can be used inside the footer. See also section "Doxygen +# usage" for information on how to generate the default footer that doxygen +# normally uses. Note: Only use a user-defined footer if you know what you are +# doing! # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_FOOTER = @@ -1755,18 +2051,26 @@ LATEX_EXTRA_FILES = PDF_HYPERLINKS = YES -# If the USE_PDFLATEX tag is set to YES, doxygen will use pdflatex to generate -# the PDF file directly from the LaTeX files. 
Set this option to YES, to get a -# higher quality PDF documentation. +# If the USE_PDFLATEX tag is set to YES, doxygen will use the engine as +# specified with LATEX_CMD_NAME to generate the PDF file directly from the LaTeX +# files. Set this option to YES, to get a higher quality PDF documentation. +# +# See also section LATEX_CMD_NAME for selecting the engine. # The default value is: YES. # This tag requires that the tag GENERATE_LATEX is set to YES. USE_PDFLATEX = YES -# If the LATEX_BATCHMODE tag is set to YES, doxygen will add the \batchmode -# command to the generated LaTeX files. This will instruct LaTeX to keep running -# if errors occur, instead of asking the user for help. This option is also used -# when generating formulas in HTML. +# The LATEX_BATCHMODE tag signals the behavior of LaTeX in case of an error. +# Possible values are: NO same as ERROR_STOP, YES same as BATCH, BATCH In batch +# mode nothing is printed on the terminal, errors are scrolled as if <return> is +# hit at every error; missing files that TeX tries to input or request from +# keyboard input (\read on a not open input stream) cause the job to abort, +# NON_STOP In nonstop mode the diagnostic message will appear on the terminal, +# but there is no possibility of user interaction just like in batch mode, +# SCROLL In scroll mode, TeX will stop only for missing files to input or if +# keyboard input is necessary and ERROR_STOP In errorstop mode, TeX will stop at +# each error, asking for user intervention. # The default value is: NO. # This tag requires that the tag GENERATE_LATEX is set to YES. @@ -1779,24 +2083,22 @@ LATEX_BATCHMODE = NO LATEX_HIDE_INDICES = NO -# If the LATEX_SOURCE_CODE tag is set to YES then doxygen will include source -# code with syntax highlighting in the LaTeX output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_LATEX is set to YES.
- -LATEX_SOURCE_CODE = NO - # The LATEX_BIB_STYLE tag can be used to specify the style to use for the # bibliography, e.g. plainnat, or ieeetr. See -# http://en.wikipedia.org/wiki/BibTeX and \cite for more info. +# https://en.wikipedia.org/wiki/BibTeX and \cite for more info. # The default value is: plain. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_BIB_STYLE = plain +# The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute) +# path from which the emoji images will be read. If a relative path is entered, +# it will be relative to the LATEX_OUTPUT directory. If left blank the +# LATEX_OUTPUT directory will be used. +# This tag requires that the tag GENERATE_LATEX is set to YES. + +LATEX_EMOJI_DIRECTORY = + #--------------------------------------------------------------------------- # Configuration options related to the RTF output #--------------------------------------------------------------------------- @@ -1836,9 +2138,9 @@ COMPACT_RTF = NO RTF_HYPERLINKS = NO -# Load stylesheet definitions from file. Syntax is similar to doxygen's config -# file, i.e. a series of assignments. You only have to provide replacements, -# missing definitions are set to their default value. +# Load stylesheet definitions from file. Syntax is similar to doxygen's +# configuration file, i.e. a series of assignments. You only have to provide +# replacements, missing definitions are set to their default value. # # See also section "Doxygen usage" for information on how to generate the # default style sheet that doxygen normally uses. @@ -1847,22 +2149,12 @@ RTF_HYPERLINKS = NO RTF_STYLESHEET_FILE = # Set optional variables used in the generation of an RTF document. Syntax is -# similar to doxygen's config file. A template extensions file can be generated -# using doxygen -e rtf extensionFile. +# similar to doxygen's configuration file. A template extensions file can be +# generated using doxygen -e rtf extensionFile. 
# This tag requires that the tag GENERATE_RTF is set to YES. RTF_EXTENSIONS_FILE = -# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code -# with syntax highlighting in the RTF output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_RTF is set to YES. - -RTF_SOURCE_CODE = NO - #--------------------------------------------------------------------------- # Configuration options related to the man page output #--------------------------------------------------------------------------- @@ -1934,6 +2226,13 @@ XML_OUTPUT = xml XML_PROGRAMLISTING = YES +# If the XML_NS_MEMB_FILE_SCOPE tag is set to YES, doxygen will include +# namespace members in file scope as well, matching the HTML output. +# The default value is: NO. +# This tag requires that the tag GENERATE_XML is set to YES. + +XML_NS_MEMB_FILE_SCOPE = NO + #--------------------------------------------------------------------------- # Configuration options related to the DOCBOOK output #--------------------------------------------------------------------------- @@ -1952,23 +2251,14 @@ GENERATE_DOCBOOK = NO DOCBOOK_OUTPUT = docbook -# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the -# program listings (including syntax highlighting and cross-referencing -# information) to the DOCBOOK output. Note that enabling this will significantly -# increase the size of the DOCBOOK output. -# The default value is: NO. -# This tag requires that the tag GENERATE_DOCBOOK is set to YES. 
- -DOCBOOK_PROGRAMLISTING = NO - #--------------------------------------------------------------------------- # Configuration options for the AutoGen Definitions output #--------------------------------------------------------------------------- # If the GENERATE_AUTOGEN_DEF tag is set to YES, doxygen will generate an -# AutoGen Definitions (see http://autogen.sf.net) file that captures the -# structure of the code including all documentation. Note that this feature is -# still experimental and incomplete at the moment. +# AutoGen Definitions (see https://autogen.sourceforge.net/) file that captures +# the structure of the code including all documentation. Note that this feature +# is still experimental and incomplete at the moment. # The default value is: NO. GENERATE_AUTOGEN_DEF = NO @@ -2047,7 +2337,8 @@ SEARCH_INCLUDES = NO # The INCLUDE_PATH tag can be used to specify one or more directories that # contain include files that are not input files but should be processed by the -# preprocessor. +# preprocessor. Note that the INCLUDE_PATH is not recursive, so the setting of +# RECURSIVE has no effect here. # This tag requires that the tag SEARCH_INCLUDES is set to YES. INCLUDE_PATH = @@ -2113,7 +2404,7 @@ TAGFILES = # tag file that is based on the input files it reads. See section "Linking to # external documentation" for more information about the usage of tag files. -GENERATE_TAGFILE = +GENERATE_TAGFILE = html/tagfile.xml # If the ALLEXTERNALS tag is set to YES, all external class will be listed in # the class index. If set to NO, only the inherited external classes will be @@ -2136,41 +2427,10 @@ EXTERNAL_GROUPS = YES EXTERNAL_PAGES = YES -# The PERL_PATH should be the absolute path and name of the perl script -# interpreter (i.e. the result of 'which perl'). -# The default file (with absolute path) is: /usr/bin/perl. 
- -PERL_PATH = /usr/bin/perl - #--------------------------------------------------------------------------- -# Configuration options related to the dot tool +# Configuration options related to diagram generator tools #--------------------------------------------------------------------------- -# If the CLASS_DIAGRAMS tag is set to YES, doxygen will generate a class diagram -# (in HTML and LaTeX) for classes with base or super classes. Setting the tag to -# NO turns the diagrams off. Note that this option also works with HAVE_DOT -# disabled, but it is recommended to install and use dot, since it yields more -# powerful graphs. -# The default value is: YES. - -CLASS_DIAGRAMS = NO - -# You can define message sequence charts within doxygen comments using the \msc -# command. Doxygen will then run the mscgen tool (see: -# http://www.mcternan.me.uk/mscgen/)) to produce the chart and insert it in the -# documentation. The MSCGEN_PATH tag allows you to specify the directory where -# the mscgen tool resides. If left empty the tool is assumed to be found in the -# default search path. - -MSCGEN_PATH = - -# You can include diagrams made with dia in doxygen documentation. Doxygen will -# then run dia to produce the diagram and insert it in the documentation. The -# DIA_PATH tag allows you to specify the directory where the dia binary resides. -# If left empty dia is assumed to be found in the default search path. - -DIA_PATH = - # If set to YES the inheritance and collaboration graphs will hide inheritance # and usage relations if the target is undocumented or is not a class. # The default value is: YES. @@ -2179,7 +2439,7 @@ HIDE_UNDOC_RELATIONS = YES # If you set the HAVE_DOT tag to YES then doxygen will assume the dot tool is # available from the path. This tool is part of Graphviz (see: -# http://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent +# https://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent # Bell Labs. 
The other options in this section have no effect if this option is # set to NO # The default value is: NO. @@ -2196,35 +2456,52 @@ HAVE_DOT = NO DOT_NUM_THREADS = 0 -# When you want a differently looking font in the dot files that doxygen -# generates you can specify the font name using DOT_FONTNAME. You need to make -# sure dot is able to find the font, which can be done by putting it in a -# standard location or by setting the DOTFONTPATH environment variable or by -# setting DOT_FONTPATH to the directory containing the font. -# The default value is: Helvetica. +# DOT_COMMON_ATTR is common attributes for nodes, edges and labels of +# subgraphs. When you want a differently looking font in the dot files that +# doxygen generates you can specify fontname, fontcolor and fontsize attributes. +# For details please see Node, +# Edge and Graph Attributes specification You need to make sure dot is able +# to find the font, which can be done by putting it in a standard location or by +# setting the DOTFONTPATH environment variable or by setting DOT_FONTPATH to the +# directory containing the font. Default graphviz fontsize is 14. +# The default value is: fontname=Helvetica,fontsize=10. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_FONTNAME = Helvetica +DOT_COMMON_ATTR = "fontname=Helvetica,fontsize=10" -# The DOT_FONTSIZE tag can be used to set the size (in points) of the font of -# dot graphs. -# Minimum value: 4, maximum value: 24, default value: 10. +# DOT_EDGE_ATTR is concatenated with DOT_COMMON_ATTR. For elegant style you can +# add 'arrowhead=open, arrowtail=open, arrowsize=0.5'. Complete documentation about +# arrows shapes. +# The default value is: labelfontname=Helvetica,labelfontsize=10. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_FONTSIZE = 10 +DOT_EDGE_ATTR = "labelfontname=Helvetica,labelfontsize=10" -# By default doxygen will tell dot to use the default font as specified with -# DOT_FONTNAME. 
If you specify a different font using DOT_FONTNAME you can set -# the path where dot can find it using this tag. +# DOT_NODE_ATTR is concatenated with DOT_COMMON_ATTR. For view without boxes +# around nodes set 'shape=plain' or 'shape=plaintext' Shapes specification +# The default value is: shape=box,height=0.2,width=0.4. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_NODE_ATTR = "shape=box,height=0.2,width=0.4" + +# You can set the path where dot can find font specified with fontname in +# DOT_COMMON_ATTR and others dot attributes. # This tag requires that the tag HAVE_DOT is set to YES. DOT_FONTPATH = -# If the CLASS_GRAPH tag is set to YES then doxygen will generate a graph for -# each documented class showing the direct and indirect inheritance relations. -# Setting this tag to YES will force the CLASS_DIAGRAMS tag to NO. +# If the CLASS_GRAPH tag is set to YES or GRAPH or BUILTIN then doxygen will +# generate a graph for each documented class showing the direct and indirect +# inheritance relations. In case the CLASS_GRAPH tag is set to YES or GRAPH and +# HAVE_DOT is enabled as well, then dot will be used to draw the graph. In case +# the CLASS_GRAPH tag is set to YES and HAVE_DOT is disabled or if the +# CLASS_GRAPH tag is set to BUILTIN, then the built-in generator will be used. +# If the CLASS_GRAPH tag is set to TEXT the direct and indirect inheritance +# relations will be shown as texts / links. +# Possible values are: NO, YES, TEXT, GRAPH and BUILTIN. # The default value is: YES. -# This tag requires that the tag HAVE_DOT is set to YES. CLASS_GRAPH = YES @@ -2238,7 +2515,8 @@ CLASS_GRAPH = YES COLLABORATION_GRAPH = YES # If the GROUP_GRAPHS tag is set to YES then doxygen will generate a graph for -# groups, showing the direct groups dependencies. +# groups, showing the direct groups dependencies. See also the chapter Grouping +# in the manual. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. 
@@ -2261,10 +2539,32 @@ UML_LOOK = NO # but if the number exceeds 15, the total amount of fields shown is limited to # 10. # Minimum value: 0, maximum value: 100, default value: 10. -# This tag requires that the tag HAVE_DOT is set to YES. +# This tag requires that the tag UML_LOOK is set to YES. UML_LIMIT_NUM_FIELDS = 10 +# If the DOT_UML_DETAILS tag is set to NO, doxygen will show attributes and +# methods without types and arguments in the UML graphs. If the DOT_UML_DETAILS +# tag is set to YES, doxygen will add type and arguments for attributes and +# methods in the UML graphs. If the DOT_UML_DETAILS tag is set to NONE, doxygen +# will not generate fields with class member information in the UML graphs. The +# class diagrams will look similar to the default class diagrams but using UML +# notation for the relationships. +# Possible values are: NO, YES and NONE. +# The default value is: NO. +# This tag requires that the tag UML_LOOK is set to YES. + +DOT_UML_DETAILS = NO + +# The DOT_WRAP_THRESHOLD tag can be used to set the maximum number of characters +# to display on a single line. If the actual line length exceeds this threshold +# significantly it will wrapped across multiple lines. Some heuristics are apply +# to avoid ugly line breaks. +# Minimum value: 0, maximum value: 1000, default value: 17. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_WRAP_THRESHOLD = 17 + # If the TEMPLATE_RELATIONS tag is set to YES then the inheritance and # collaboration graphs will show the relations between templates and their # instances. @@ -2331,10 +2631,17 @@ GRAPHICAL_HIERARCHY = YES DIRECTORY_GRAPH = YES +# The DIR_GRAPH_MAX_DEPTH tag can be used to limit the maximum number of levels +# of child directories generated in directory dependency graphs by dot. +# Minimum value: 1, maximum value: 25, default value: 1. +# This tag requires that the tag DIRECTORY_GRAPH is set to YES. 
+ +DIR_GRAPH_MAX_DEPTH = 1 + # The DOT_IMAGE_FORMAT tag can be used to set the image format of the images # generated by dot. For an explanation of the image formats see the section # output formats in the documentation of the dot tool (Graphviz (see: -# http://www.graphviz.org/)). +# https://www.graphviz.org/)). # Note: If you choose svg you need to set HTML_FILE_EXTENSION to xhtml in order # to make the SVG files visible in IE 9+ (other browsers do not have this # requirement). @@ -2344,7 +2651,7 @@ DIRECTORY_GRAPH = YES # The default value is: png. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_IMAGE_FORMAT = png +DOT_IMAGE_FORMAT = svg # If DOT_IMAGE_FORMAT is set to svg, then this option can be set to YES to # enable generation of interactive SVG images that allow zooming and panning. @@ -2356,7 +2663,7 @@ DOT_IMAGE_FORMAT = png # The default value is: NO. # This tag requires that the tag HAVE_DOT is set to YES. -INTERACTIVE_SVG = NO +INTERACTIVE_SVG = YES # The DOT_PATH tag can be used to specify the path where the dot tool can be # found. If left blank, it is assumed the dot tool can be found in the path. @@ -2371,11 +2678,12 @@ DOT_PATH = DOTFILE_DIRS = -# The MSCFILE_DIRS tag can be used to specify one or more directories that -# contain msc files that are included in the documentation (see the \mscfile -# command). +# You can include diagrams made with dia in doxygen documentation. Doxygen will +# then run dia to produce the diagram and insert it in the documentation. The +# DIA_PATH tag allows you to specify the directory where the dia binary resides. +# If left empty dia is assumed to be found in the default search path. 
-MSCFILE_DIRS = +DIA_PATH = # The DIAFILE_DIRS tag can be used to specify one or more directories that # contain dia files that are included in the documentation (see the \diafile @@ -2384,13 +2692,18 @@ MSCFILE_DIRS = DIAFILE_DIRS = # When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the -# path where java can find the plantuml.jar file. If left blank, it is assumed -# PlantUML is not used or called during a preprocessing step. Doxygen will -# generate a warning when it encounters a \startuml command in this case and -# will not generate output for the diagram. +# path where java can find the plantuml.jar file or to the filename of jar file +# to be used. If left blank, it is assumed PlantUML is not used or called during +# a preprocessing step. Doxygen will generate a warning when it encounters a +# \startuml command in this case and will not generate output for the diagram. PLANTUML_JAR_PATH = +# When using plantuml, the PLANTUML_CFG_FILE tag can be used to specify a +# configuration file for plantuml. + +PLANTUML_CFG_FILE = + # When using plantuml, the specified paths are searched for files specified by # the !include statement in a plantuml block. @@ -2420,18 +2733,6 @@ DOT_GRAPH_MAX_NODES = 50 MAX_DOT_GRAPH_DEPTH = 0 -# Set the DOT_TRANSPARENT tag to YES to generate images with a transparent -# background. This is disabled by default, because dot on Windows does not seem -# to support this out of the box. -# -# Warning: Depending on the platform used, enabling this option may lead to -# badly anti-aliased labels on the edges of a graph (i.e. they become hard to -# read). -# The default value is: NO. -# This tag requires that the tag HAVE_DOT is set to YES. - -DOT_TRANSPARENT = NO - # Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output # files in one run (i.e. multiple -o and -T options on the command line). 
This # makes dot run faster, but since only newer versions of dot (>1.8.10) support @@ -2444,14 +2745,34 @@ DOT_MULTI_TARGETS = NO # If the GENERATE_LEGEND tag is set to YES doxygen will generate a legend page # explaining the meaning of the various boxes and arrows in the dot generated # graphs. +# Note: This tag requires that UML_LOOK isn't set, i.e. the doxygen internal +# graphical representation for inheritance and collaboration diagrams is used. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. GENERATE_LEGEND = YES -# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate dot +# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate # files that are used to generate the various graphs. +# +# Note: This setting is not only used for dot files but also for msc temporary +# files. # The default value is: YES. -# This tag requires that the tag HAVE_DOT is set to YES. DOT_CLEANUP = YES + +# You can define message sequence charts within doxygen comments using the \msc +# command. If the MSCGEN_TOOL tag is left empty (the default), then doxygen will +# use a built-in version of mscgen tool to produce the charts. Alternatively, +# the MSCGEN_TOOL tag can also specify the name an external tool. For instance, +# specifying prog as the value, doxygen will call the tool as prog -T +# -o . The external tool should support +# output file formats "png", "eps", "svg", and "ismap". + +MSCGEN_TOOL = + +# The MSCFILE_DIRS tag can be used to specify one or more directories that +# contain msc files that are included in the documentation (see the \mscfile +# command). 
+ +MSCFILE_DIRS = diff --git a/docs/index.rst b/docs/index.rst index 30ef672f84..89a5e3e836 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -8,31 +8,38 @@ Composable Kernel User Guide ******************************************************************** -The Composable Kernel (CK) library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages like HIP C++. This document contains instructions for installing, using, and contributing to the Composable Kernel project. To learn more see :ref:`what-is-ck`. +The Composable Kernel library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages such as `HIP C++ `_. -The CK documentation is structured as follows: +The Composable Kernel repository is located at `https://github.com/ROCm/composable_kernel `_. .. grid:: 2 :gutter: 3 - .. grid-item-card:: Installation + .. grid-item-card:: Install - * :ref:`docker-hub` + * :doc:`Composable Kernel prerequisites <./install/Composable-Kernel-prerequisites>` + * :doc:`Build and install Composable Kernel <./install/Composable-Kernel-install>` + * :doc:`Build and install Composable Kernel on a Docker image <./install/Composable-Kernel-Docker>` .. grid-item-card:: Conceptual - * :ref:`what-is-ck` + * :doc:`Composable Kernel structure <./conceptual/Composable-Kernel-structure>` + * :doc:`Composable Kernel mathematical basis <./conceptual/Composable-Kernel-math>` - .. grid-item-card:: API reference + .. grid-item-card:: Tutorials - * :ref:`supported-primitives` - * :ref:`api-reference` - * :ref:`wrapper` + * :doc:`Composable Kernel examples and tests <./tutorial/Composable-Kernel-examples>` - .. grid-item-card:: Tutorial - - * :ref:`hello-world` + .. 
grid-item-card:: Reference + * :doc:`Composable Kernel supported scalar types <./reference/Composable_Kernel_supported_scalar_types>` + * :doc:`Composable Kernel custom types <./reference/Composable_Kernel_custom_types>` + * :doc:`Composable Kernel vector utilities <./reference/Composable_Kernel_vector_utilities>` + * :ref:`wrapper` + * :doc:`Composable Kernel API reference <./doxygen/html/namespace_c_k>` + * :doc:`CK Tile API reference <./doxygen/html/namespaceck__tile>` + * :doc:`Composable Kernel complete API class list <./doxygen/html/annotated>` + To contribute to the documentation refer to `Contributing to ROCm `_. You can find licensing information on the `Licensing `_ page. diff --git a/docs/install/Composable-Kernel-Docker.rst b/docs/install/Composable-Kernel-Docker.rst new file mode 100644 index 0000000000..d40cc2bff5 --- /dev/null +++ b/docs/install/Composable-Kernel-Docker.rst @@ -0,0 +1,16 @@ +.. meta:: + :description: Composable Kernel docker files + :keywords: composable kernel, CK, ROCm, API, docker + +.. _docker-hub: + +******************************************************************** +Composable Kernel Docker containers +******************************************************************** + +Docker images that include all the required prerequisites for building Composable Kernel are available on `Docker Hub `_. + +The images also contain `ROCm `_, `CMake `_, and the `ROCm LLVM compiler infrastructure `_. + +Composable Kernel Docker images are named according to their operating system and ROCm version. For example, a Docker image named ``ck_ub22.04_rocm6.3`` would correspond to an Ubuntu 22.04 image with ROCm 6.3. + diff --git a/docs/install/Composable-Kernel-install.rst b/docs/install/Composable-Kernel-install.rst new file mode 100644 index 0000000000..61b1fe0fcb --- /dev/null +++ b/docs/install/Composable-Kernel-install.rst @@ -0,0 +1,72 @@ +.. 
meta:: + :description: Composable Kernel build and install + :keywords: composable kernel, CK, ROCm, API, documentation, install + +****************************************************** +Building and installing Composable Kernel with CMake +****************************************************** + +Before you begin, clone the `Composable Kernel GitHub repository `_ and create a ``build`` directory in its root: + +.. code:: shell + + git clone https://github.com/ROCm/composable_kernel.git + cd composable_kernel + mkdir build + +Change directory to the ``build`` directory and generate the makefile using the ``cmake`` command. Two build options are required: + +* ``CMAKE_PREFIX_PATH``: The ROCm installation path. ROCm is installed in ``/opt/rocm`` by default. +* ``CMAKE_CXX_COMPILER``: The path to the Clang compiler. Clang is found at ``/opt/rocm/llvm/bin/clang++`` by default. + + +.. code:: shell + + cd build + cmake ../. -D CMAKE_PREFIX_PATH="/opt/rocm" -D CMAKE_CXX_COMPILER="/opt/rocm/llvm/bin/clang++" [-D [-D] ...] + + +Other build options are: + +* ``DISABLE_DL_KERNELS``: Set this to "ON" to not build deep learning (DL) and data parallel primitive (DPP) instances. + + .. note:: + + DL and DPP instances are useful on architectures that don't support XDL or WMMA. + +* ``CK_USE_FP8_ON_UNSUPPORTED_ARCH``: Set to ``ON`` to build FP8 data type instances on gfx90a without native FP8 support. +* ``GPU_TARGETS``: Target architectures. Target architectures in this list must all be different versions of the same architectures. Enclose the list of targets in quotation marks. Separate multiple targets with semicolons (``;``). For example, ``cmake -D GPU_TARGETS="gfx908;gfx90a"``. This option is required to build tests and examples. +* ``GPU_ARCHS``: Target architectures. Target architectures in this list are not limited to different versions of the same architectures. Enclose the list of targets in quotation marks. Separate multiple targets with semicolons (``;``). 
For example, ``cmake -D GPU_ARCHS="gfx908;gfx1100"``. +* ``CMAKE_BUILD_TYPE``: The build type. Can be ``None``, ``Release``, ``Debug``, ``RelWithDebInfo``, or ``MinSizeRel``. CMake will use ``Release`` by default. + +.. Note:: + + If neither ``GPU_TARGETS`` nor ``GPU_ARCHS`` is specified, Composable Kernel will be built for all targets supported by the compiler. + +Build Composable Kernel using the generated makefile. This will build the library, the examples, and the tests, and save them to ``bin``. + +.. code:: shell + + make -j20 + +The ``-j`` option speeds up the build by using multiple threads in parallel. For example, ``-j20`` uses twenty threads in parallel. On average, each thread will use 2GB of memory. Make sure that the number of threads you use doesn't exceed the available memory in your system. + +Using ``-j`` alone will launch an unlimited number of threads and is not recommended. + +Install the Composable Kernel library: + +.. code:: shell + + make install + +After running ``make install``, the Composable Kernel files will be saved to the following locations: + +* Library files: ``/opt/rocm/lib/`` +* Header files: ``/opt/rocm/include/ck/`` and ``/opt/rocm/include/ck_tile/`` +* Examples, tests, and ckProfiler: ``/opt/rocm/bin/`` + +For information about ckProfiler, see `the ckProfiler readme file `_. + +For information about running the examples and tests, see :doc:`Composable Kernel examples and tests <../tutorial/Composable-Kernel-examples>`. + + diff --git a/docs/install/Composable-Kernel-prerequisites.rst b/docs/install/Composable-Kernel-prerequisites.rst new file mode 100644 index 0000000000..9dc082599a --- /dev/null +++ b/docs/install/Composable-Kernel-prerequisites.rst @@ -0,0 +1,32 @@ +..
meta:: + :description: Composable Kernel prerequisites + :keywords: composable kernel, CK, ROCm, API, documentation, prerequisites + +****************************************************** +Composable Kernel prerequisites +****************************************************** + +Docker images that include all the required prerequisites for building Composable Kernel are available on `Docker Hub `_. + +The following prerequisites are required to build and install Composable Kernel: + +* cmake +* hip-rocclr +* iputils-ping +* jq +* libelf-dev +* libncurses5-dev +* libnuma-dev +* libpthread-stubs0-dev +* llvm-amdgpu +* mpich +* net-tools +* python3 +* python3-dev +* python3-pip +* redis +* rocm-llvm-dev +* zlib1g-dev +* libzstd-dev +* openssh-server +* clang-format-18 diff --git a/docs/install/dockerhub.rst b/docs/install/dockerhub.rst deleted file mode 100644 index 87eb5a4f81..0000000000 --- a/docs/install/dockerhub.rst +++ /dev/null @@ -1,101 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _docker-hub: - -******************************************************************** -CK Docker Hub -******************************************************************** - -Why do I need this? -=================== - -To make things simpler, and bring Composable Kernel and its dependencies together, -docker images can be found on `Docker Hub `_. Docker images provide a complete image of the OS, the Composable Kernel library, and its dependencies in a single downloadable file. - -Refer to `Docker Overview `_ for more information on Docker images and containers. - -Which image is right for me? -============================ - -The image naming includes information related to the docker image. 
-For example ``ck_ub20.04_rocm6.0`` indicates the following: - -* ``ck`` - made for running Composable Kernel; -* ``ub20.04`` - based on Ubuntu 20.04; -* ``rocm6.0`` - ROCm platform version 6.0. - -Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. Use the ``docker pull`` command to download the file:: - - docker pull rocm/composable_kernel:ck_ub20.04_rocm6.0 - - -What is inside the image? -------------------------- - -The docker images have everything you need for running CK including: - -* `ROCm `_ -* `CMake `_ -* `Compiler `_ -* `Composable Kernel library `_ - -Running the docker container -============================ - -After downloading the docker image, you can start the container using one of a number of commands. Start with the ``docker run`` command as shown below:: - - docker run \ - -it \ - --privileged \ - --group-add sudo \ - -w /root/workspace \ - -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ - rocm/composable_kernel:ck_ub20.04_rocm6.0 \ - /bin/bash - -After starting the bash shell, the docker container current folder is `~/workspace`. The library path is ``~/workspace/composable_kernel``. Navigate to the library to begin the tutorial as explained in :ref:`hello-world`: - -.. note:: - - If your current folder is different from `${HOME}`, adjust the line ``-v ${HOME}:/root/workspace`` in the ``docker run`` command to fit your folder structure. - -Stop and restart the docker image -================================= - -After finishing the tutorial, or just when you have completed your work session, you can close the docker container, or stop the docker container to restart it at another time. Closing the docker container means that it is still in the active state, and can be resumed from where you left it. Stopping the container closes it, and returns the image to its initial state. 
- -Use the ``Ctrl-D`` option to exit the container, while leaving it active, so you can return to the container in its current state to resume the tutorial, or pickup your project where you left off. - -To restart the active container use the ``docker exec`` command to specify the container name and options as follows:: - - docker exec -it bash - -Where: - -* `exec` is the docker command -* `-it` is the interactive option for `exec` -* `` specifies an active container on the system -* `bash` specifies the command to run in the interactive shell - -.. note:: - - You can use the ``docker container ls`` command to list the active containers on the system. - -To start a container from the image, use the ``docker start`` command:: - - docker start - -Then use the docker exec command as shown above to start the bash shell. - -Use the ``docker stop`` command to stop the container and restore the image to its initial state:: - - docker stop - -Editing the docker image -======================= - -If you want to customize the docker image, edit the -`Dockerfile `_ -from the GitHub repository to suit your needs. diff --git a/docs/reference/API_Reference_Guide.rst b/docs/reference/API_Reference_Guide.rst deleted file mode 100644 index 0d2d41c1eb..0000000000 --- a/docs/reference/API_Reference_Guide.rst +++ /dev/null @@ -1,48 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _api-reference: - -******************************************************************** -API reference guide -******************************************************************** - - -This document contains details of the APIs for the Composable Kernel (CK) library and introduces -some of the key design principles that are used to write new classes that extend CK functionality. - -================= -CK Datatypes -================= - ------------------ -DeviceMem ------------------ - -.. 
doxygenstruct:: DeviceMem - ---------------------------- -Kernels For Flashattention ---------------------------- - -The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists -the classes that are used in the CK GPU implementation of Flashattention. - -**Gridwise classes** - -.. doxygenstruct:: ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle - -**Blockwise classes** - -.. doxygenstruct:: ck::ThreadGroupTensorSliceTransfer_v4r1 - -.. doxygenstruct:: ck::BlockwiseGemmXdlops_v2 - -.. doxygenstruct:: ck::BlockwiseSoftmax - -**Threadwise classes** - -.. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic - -.. bibliography:: diff --git a/docs/reference/wrapper.rst b/docs/reference/Composable-Kernel-wrapper.rst similarity index 88% rename from docs/reference/wrapper.rst rename to docs/reference/Composable-Kernel-wrapper.rst index 190fbcd445..4baa8d2b64 100644 --- a/docs/reference/wrapper.rst +++ b/docs/reference/Composable-Kernel-wrapper.rst @@ -1,20 +1,15 @@ .. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation + :description: Composable Kernel wrapper + :keywords: composable kernel, CK, ROCm, API, wrapper .. _wrapper: ******************************************************************** -Wrapper +Composable Kernel wrapper ******************************************************************** -------------------------------------- -Description -------------------------------------- - -The CK library provides a lightweight wrapper for more complex operations implemented in -the library. +The Composable Kernel library provides a lightweight wrapper to simplify the more complex operations. 
Example: diff --git a/docs/reference/Composable_Kernel_custom_types.rst b/docs/reference/Composable_Kernel_custom_types.rst new file mode 100644 index 0000000000..863d4131b9 --- /dev/null +++ b/docs/reference/Composable_Kernel_custom_types.rst @@ -0,0 +1,39 @@ +.. meta:: + :description: Composable Kernel supported custom types + :keywords: composable kernel, custom, data types, support, CK, ROCm + +****************************************************** +Composable Kernel custom data types +****************************************************** + +Composable Kernel supports the use of custom types that provide a way to implement specialized numerical formats. + +To use custom types, a C++ type that implements the necessary operations for tensor computations needs to be created. These should include: + +* Constructors and initialization methods +* Arithmetic operators if the type will be used in computational operations +* Any conversion functions needed to interface with other parts of an application + +For example, to create a complex half-precision type: + +.. code:: cpp + + struct complex_half_t + { + half_t real; + half_t img; + }; + + struct complex_half_t + { + using type = half_t; + type real; + type img; + + complex_half_t() : real{type{}}, img{type{}} {} + complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {} + }; + +Custom types can be particularly useful for specialized applications such as complex number arithmetic, +custom quantization schemes, or domain-specific number representations. + diff --git a/docs/reference/Composable_Kernel_supported_scalar_types.rst b/docs/reference/Composable_Kernel_supported_scalar_types.rst new file mode 100644 index 0000000000..7ea1a9eaeb --- /dev/null +++ b/docs/reference/Composable_Kernel_supported_scalar_types.rst @@ -0,0 +1,69 @@ +.. 
meta:: + :description: Composable Kernel supported scalar types + :keywords: composable kernel, scalar, data types, support, CK, ROCm + +*************************************************** +Composable Kernel supported scalar data types +*************************************************** + +The Composable Kernel library provides support for the following scalar data types: + +.. list-table:: + :header-rows: 1 + :widths: 25 15 60 + + * - Type + - Bit Width + - Description + + * - ``double`` + - 64-bit + - Standard IEEE 754 double precision floating point + + * - ``float`` + - 32-bit + - Standard IEEE 754 single precision floating point + + * - ``int32_t`` + - 32-bit + - Standard signed 32-bit integer + + * - ``int8_t`` + - 8-bit + - Standard signed 8-bit integer + + * - ``uint8_t`` + - 8-bit + - Standard unsigned 8-bit integer + + * - ``bool`` + - 1-bit + - Boolean type + + * - ``ck::half_t`` + - 16-bit + - IEEE 754 half precision floating point with 5 exponent bits, 10 mantissa bits, and 1 sign bit + + * - ``ck::bhalf_t`` + - 16-bit + - Brain floating point with 8 exponent bits, 7 mantissa bits, and 1 sign bit + + * - ``ck::f8_t`` + - 8-bit + - 8-bit floating point (E4M3 format) with 4 exponent bits, 3 mantissa bits, and 1 sign bit + + * - ``ck::bf8_t`` + - 8-bit + - 8-bit brain floating point (E5M2 format) with 5 exponent bits, 2 mantissa bits, and 1 sign bit + + * - ``ck::f4_t`` + - 4-bit + - 4-bit floating point format (E2M1 format) with 2 exponent bits, 1 mantissa bit, and 1 sign bit + + * - ``ck::f6_t`` + - 6-bit + - 6-bit floating point format (E2M3 format) with 2 exponent bits, 3 mantissa bits, and 1 sign bit + + * - ``ck::bf6_t`` + - 6-bit + - 6-bit brain floating point format (E3M2 format) with 3 exponent bits, 2 mantissa bits, and 1 sign bit \ No newline at end of file diff --git a/docs/reference/Composable_Kernel_vector_utilities.rst b/docs/reference/Composable_Kernel_vector_utilities.rst new file mode 100644 index 0000000000..3103653191 --- /dev/null 
+++ b/docs/reference/Composable_Kernel_vector_utilities.rst @@ -0,0 +1,16 @@ +.. meta:: + :description: Composable Kernel supported precision types and custom type support + :keywords: composable kernel, precision, data types, ROCm + +****************************************************** +Composable Kernel vector template utilities +****************************************************** + +Composable Kernel includes template utilities for creating vector types with customizable widths. These template utilities also flatten nested vector types into a single, wider vector, preventing the creation of vectors of vectors. + +Vectors composed of supported scalar and custom types can be created with the ``ck::vector_type`` template. + +For example, ``ck::vector_type`` creates a vector composed of four floats and ``ck::vector_type`` creates a vector composed of eight half-precision scalars. + +For vector operations to be valid, the underlying types must be either a :doc:`supported scalar type ` or :doc:`a custom type ` that implements the required operations. + diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 533b81cd39..2ef3383d84 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -3,34 +3,43 @@ defaults: root: index subtrees: -- caption: Conceptual - entries: - - file: conceptual/what-is-ck.rst - title: What is Composable Kernel? 
- - caption: Install entries: - - file: install/dockerhub.rst - title: Docker Hub + - file: install/Composable-Kernel-prerequisites.rst + title: Composable Kernel prerequisites + - file: install/Composable-Kernel-install.rst + title: Build and install Composable Kernel + - file: install/Composable-Kernel-Docker.rst + title: Composable Kernel Docker images -- caption: CK API Reference +- caption: Conceptual entries: - - file: reference/Supported_Primitives_Guide.rst - title: Supported Primitives - - file: reference/API_Reference_Guide.rst - title: API Reference - - file: reference/wrapper.rst - title: Wrapper + - file: conceptual/Composable-Kernel-structure.rst + title: Composable Kernel structure + - file: conceptual/Composable-Kernel-math.rst + title: Composable Kernel mathematical basis - caption: Tutorial entries: - - file: tutorial/tutorial_hello_world.rst - title: Hello World Tutorial + - file: tutorial/Composable-Kernel-examples.rst + title: Composable Kernel examples + +- caption: Reference + entries: + - file: reference/Composable_Kernel_supported_scalar_types.rst + title: Composable Kernel scalar types + - file: reference/Composable_Kernel_custom_types.rst + title: Composable Kernel custom types + - file: reference/Composable_Kernel_vector_utilities.rst + title: Composable Kernel vector utilities + - file: reference/Composable-Kernel-wrapper.rst + title: Composable Kernel wrapper + - file: doxygen/html/annotated.rst + title: Composable Kernel class list - caption: About entries: - file: Contributors_Guide.rst - title: Contributing to CK + title: Contributing to Composable Kernel - file: license.rst title: License - \ No newline at end of file diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 2fcf3b3935..beedb4e867 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.18.1 -sphinxcontrib-bibtex==2.6.3 +rocm-docs-core[api_reference]==1.20.1 +sphinxcontrib-bibtex==2.6.5 diff 
--git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 12572d400e..e8aa02aa01 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -6,68 +6,79 @@ # accessible-pygments==0.0.5 # via pydata-sphinx-theme -alabaster==0.7.16 +alabaster==1.0.0 # via sphinx asttokens==3.0.0 # via stack-data -attrs==24.3.0 +attrs==25.3.0 # via # jsonschema # jupyter-cache # referencing -babel==2.15.0 +babel==2.17.0 # via # pydata-sphinx-theme # sphinx -beautifulsoup4==4.12.3 +beautifulsoup4==4.13.4 # via pydata-sphinx-theme -breathe==4.35.0 +breathe==4.36.0 # via rocm-docs-core -certifi==2024.7.4 +certifi==2025.1.31 # via requests -cffi==1.16.0 +cffi==1.17.1 # via # cryptography # pynacl -charset-normalizer==3.3.2 +charset-normalizer==3.4.1 # via requests -click==8.1.7 +click==8.1.8 # via + # click-log + # doxysphinx # jupyter-cache # sphinx-external-toc +click-log==0.4.0 + # via doxysphinx comm==0.2.2 # via ipykernel -cryptography==43.0.0 +contourpy==1.3.2 + # via matplotlib +cryptography==44.0.2 # via pyjwt -debugpy==1.8.12 +cycler==0.12.1 + # via matplotlib +debugpy==1.8.14 # via ipykernel -decorator==5.1.1 +decorator==5.2.1 # via ipython -deprecated==1.2.14 +deprecated==1.2.18 # via pygithub docutils==0.21.2 # via - # breathe # myst-parser # pybtex-docutils # pydata-sphinx-theme # sphinx # sphinxcontrib-bibtex +doxysphinx==3.3.12 + # via rocm-docs-core exceptiongroup==1.2.2 # via ipython -executing==2.1.0 +executing==2.2.0 # via stack-data -fastjsonschema==2.20.0 +fastjsonschema==2.21.1 # via # nbformat # rocm-docs-core -gitdb==4.0.11 +fonttools==4.57.0 + # via matplotlib +gitdb==4.0.12 # via gitpython -gitpython==3.1.43 +gitpython==3.1.44 # via rocm-docs-core -greenlet==3.1.1 +greenlet==3.2.1 # via sqlalchemy -idna==3.7 +idna==3.10 # via requests imagesize==1.4.1 # via sphinx @@ -77,13 +88,13 @@ importlib-metadata==8.6.1 # myst-nb ipykernel==6.29.5 # via myst-nb -ipython==8.31.0 +ipython==8.35.0 # via # ipykernel # myst-nb jedi==0.19.2 # 
via ipython -jinja2==3.1.4 +jinja2==3.1.6 # via # myst-parser # sphinx @@ -103,25 +114,35 @@ jupyter-core==5.7.2 # jupyter-client # nbclient # nbformat +kiwisolver==1.4.8 + # via matplotlib latexcodec==3.0.0 # via pybtex +libsass==0.22.0 + # via doxysphinx +lxml==5.2.1 + # via doxysphinx markdown-it-py==3.0.0 # via # mdit-py-plugins # myst-parser -markupsafe==2.1.5 +markupsafe==3.0.2 # via jinja2 +matplotlib==3.10.1 + # via doxysphinx matplotlib-inline==0.1.7 # via # ipykernel # ipython -mdit-py-plugins==0.4.1 +mdit-py-plugins==0.4.2 # via myst-parser mdurl==0.1.2 # via markdown-it-py -myst-nb==1.1.2 +mpire==2.10.2 + # via doxysphinx +myst-nb==1.2.0 # via rocm-docs-core -myst-parser==3.0.1 +myst-parser==4.0.1 # via myst-nb nbclient==0.10.2 # via @@ -134,26 +155,34 @@ nbformat==5.10.4 # nbclient nest-asyncio==1.6.0 # via ipykernel -packaging==24.1 +numpy==1.26.4 + # via + # contourpy + # doxysphinx + # matplotlib +packaging==25.0 # via # ipykernel + # matplotlib # pydata-sphinx-theme # sphinx parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -platformdirs==4.3.6 +pillow==11.2.1 + # via matplotlib +platformdirs==4.3.7 # via jupyter-core -prompt-toolkit==3.0.50 +prompt-toolkit==3.0.51 # via ipython -psutil==6.1.1 +psutil==7.0.0 # via ipykernel ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data -pybtex==0.24.0 +pybtex==0.25.1 # via # pybtex-docutils # sphinxcontrib-bibtex @@ -165,21 +194,30 @@ pydata-sphinx-theme==0.15.4 # via # rocm-docs-core # sphinx-book-theme -pygithub==2.3.0 +pygithub==2.6.1 # via rocm-docs-core -pygments==2.18.0 +pygments==2.19.1 # via # accessible-pygments # ipython + # mpire # pydata-sphinx-theme # sphinx -pyjwt[crypto]==2.8.0 +pyjson5==1.6.8 + # via doxysphinx +pyjwt[crypto]==2.10.1 # via pygithub pynacl==1.5.0 # via pygithub +pyparsing==3.2.3 + # via + # doxysphinx + # matplotlib python-dateutil==2.9.0.post0 - # via jupyter-client -pyyaml==6.0.1 + # via + # jupyter-client + # matplotlib +pyyaml==6.0.2 # via # jupyter-cache # 
myst-nb @@ -187,11 +225,11 @@ pyyaml==6.0.1 # pybtex # rocm-docs-core # sphinx-external-toc -pyzmq==26.2.0 +pyzmq==26.4.0 # via # ipykernel # jupyter-client -referencing==0.36.1 +referencing==0.36.2 # via # jsonschema # jsonschema-specifications @@ -199,23 +237,21 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.18.1 +rocm-docs-core[api-reference]==1.20.1 # via -r requirements.in -rpds-py==0.22.3 +rpds-py==0.24.0 # via # jsonschema # referencing -six==1.16.0 - # via - # pybtex - # python-dateutil -smmap==5.0.1 +six==1.17.0 + # via python-dateutil +smmap==5.0.2 # via gitdb snowballstemmer==2.2.0 # via sphinx -soupsieve==2.5 +soupsieve==2.7 # via beautifulsoup4 -sphinx==7.4.7 +sphinx==8.1.3 # via # breathe # myst-nb @@ -228,19 +264,19 @@ sphinx==7.4.7 # sphinx-external-toc # sphinx-notfound-page # sphinxcontrib-bibtex -sphinx-book-theme==1.1.3 +sphinx-book-theme==1.1.4 # via rocm-docs-core sphinx-copybutton==0.5.2 # via rocm-docs-core -sphinx-design==0.6.0 +sphinx-design==0.6.1 # via rocm-docs-core sphinx-external-toc==1.0.1 # via rocm-docs-core -sphinx-notfound-page==1.0.3 +sphinx-notfound-page==1.1.0 # via rocm-docs-core sphinxcontrib-applehelp==2.0.0 # via sphinx -sphinxcontrib-bibtex==2.6.3 +sphinxcontrib-bibtex==2.6.5 # via -r requirements.in sphinxcontrib-devhelp==2.0.0 # via sphinx @@ -252,18 +288,20 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx -sqlalchemy==2.0.37 +sqlalchemy==2.0.40 # via jupyter-cache stack-data==0.6.3 # via ipython tabulate==0.9.0 # via jupyter-cache -tomli==2.0.1 +tomli==2.2.1 # via sphinx tornado==6.4.2 # via # ipykernel # jupyter-client +tqdm==4.67.1 + # via mpire traitlets==5.14.3 # via # comm @@ -274,21 +312,22 @@ traitlets==5.14.3 # matplotlib-inline # nbclient # nbformat -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via + # beautifulsoup4 # ipython # myst-nb # pydata-sphinx-theme # pygithub # referencing # sqlalchemy -urllib3==2.2.2 +urllib3==2.4.0 # via # 
pygithub # requests wcwidth==0.2.13 # via prompt-toolkit -wrapt==1.16.0 +wrapt==1.17.2 # via deprecated zipp==3.21.0 # via importlib-metadata diff --git a/docs/tutorial/Composable-Kernel-examples.rst b/docs/tutorial/Composable-Kernel-examples.rst new file mode 100644 index 0000000000..62422d6f15 --- /dev/null +++ b/docs/tutorial/Composable-Kernel-examples.rst @@ -0,0 +1,40 @@ +.. meta:: + :description: Composable Kernel examples and tests + :keywords: composable kernel, CK, ROCm, API, examples, tests + +******************************************************************** +Composable Kernel examples and tests +******************************************************************** + +After :doc:`building and installing Composable Kernel <../install/Composable-Kernel-install>`, the examples and tests will be moved to ``/opt/rocm/bin/``. + +All tests have the prefix ``test`` and all examples have the prefix ``example``. + +Use ``ctest`` with no arguments to run all examples and tests, or use ``ctest -R`` to run a single test. For example: + +.. code:: shell + + ctest -R test_gemm_fp16 + +Examples can be run individually as well. For example: + +.. code:: shell + + ./bin/example_gemm_xdl_fp16 1 1 1 + +For instructions on how to run individual examples and tests, see their README files in the |example|_ and |test|_ GitHub folders. + +To run smoke tests, use ``make smoke``. + +To run regression tests, use ``make regression``. + +In general, tests that run for under thirty seconds are included in the smoke tests and tests that run for over thirty seconds are included in the regression tests. + +.. |example| replace:: ``example`` +.. _example: https://github.com/ROCm/composable_kernel/tree/develop/example + +.. |client_example| replace:: ``client_example`` +.. _client_example: https://github.com/ROCm/composable_kernel/tree/develop/client_example + +.. |test| replace:: ``test`` +.. 
_test: https://github.com/ROCm/composable_kernel/tree/develop/test \ No newline at end of file diff --git a/docs/tutorial/tutorial_hello_world.rst b/docs/tutorial/tutorial_hello_world.rst deleted file mode 100644 index c31460785b..0000000000 --- a/docs/tutorial/tutorial_hello_world.rst +++ /dev/null @@ -1,165 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _hello-world: - -******************************************************************** -Hello World Tutorial -******************************************************************** - -This tutorial is for engineers dealing with artificial intelligence and machine learning who -would like to optimize pipelines and improve performance using the Composable -Kernel (CK) library. This tutorial provides an introduction to the CK library. You will build the library and run some examples using a "Hello World" example. - -Description -=========== - -Modern AI technology solves more and more problems in a variety of fields, but crafting fast and -efficient workflows is still challenging. CK can make the AI workflow fast -and efficient. CK is a collection of optimized AI operator kernels with tools to create -new kernels. The library has components required for modern neural network architectures -including matrix multiplication, convolution, contraction, reduction, attention modules, a variety of activation functions, and fused operators. - -CK library acceleration features are based on: - -* Layered structure -* Tile-based computation model -* Tensor coordinate transformation -* Hardware acceleration use -* Support of low precision data types including fp16, bf16, int8 and int4 - -If you need more technical details and benchmarking results read the following -`blog post `_. - -To download the library visit the `composable_kernel repository `_. 
- -Hardware targets -================ - -CK library fully supports `gfx908` and `gfx90a` GPU architectures, while only some operators are -supported for `gfx1030` devices. Check your hardware to determine the target GPU architecture. - -========== ========= -GPU Target AMD GPU -========== ========= -gfx908 Radeon Instinct MI100 -gfx90a Radeon Instinct MI210, MI250, MI250X -gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT -========== ========= - -There are also `cloud options `_ you can find if -you don't have an AMD GPU at hand. - -Build the library -================= - -This tutorial is based on the use of docker images as explained in :ref:`docker-hub`. Download a docker image suitable for your OS and ROCm release, run or start the docker container, and then resume the tutorial from this point. - -.. note:: - - You can also `install ROCm `_ on your system, clone the `Composable Kernel repository `_ on GitHub, and use that to build and run the examples using the commands described below. - -Both the docker container and GitHub repository include the Composable Kernel library. Navigate to the library:: - - cd composable_kernel/ - -Create and change to a ``build`` directory:: - - mkdir build && cd build - -The previous section discussed supported GPU architecture. Once you decide which hardware targets are needed, run CMake using the ``GPU_TARGETS`` flag:: - - cmake \ - -D CMAKE_PREFIX_PATH=/opt/rocm \ - -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ - -D CMAKE_CXX_FLAGS="-O3" \ - -D CMAKE_BUILD_TYPE=Release \ - -D BUILD_DEV=OFF \ - -D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. 
- -If everything goes well the CMake command will return:: - - -- Configuring done - -- Generating done - -- Build files have been written to: "/root/workspace/composable_kernel/build" - -Finally, you can build examples and tests:: - - make -j examples tests - -When complete you should see:: - - Scanning dependencies of target tests - [100%] Built target tests - -Run examples and tests -====================== - -Examples are listed as test cases as well, so you can run all examples and tests with:: - - ctest - -You can check the list of all tests by running:: - - ctest -N - -You can also run examples separately as shown in the following example execution:: - - ./bin/example_gemm_xdl_fp16 1 1 1 - -The arguments ``1 1 1`` mean that you want to run this example in the mode: verify results with CPU, initialize matrices with integers, and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. - -If you have a device based on `gfx908` or `gfx90a` architecture, and if the example runs as expected, you should see something like:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - Perf: 1.08153 ms, 119.136 TFlops, 89.1972 GB/s, DeviceGemm_Xdl_CShuffle LoopScheduler: Interwave, PipelineVersion: v1 - -However, running it on a `gfx1030` device should result in the following:: - - a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} - b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem - -Don't worry, some operators are supported on `gfx1030` architecture, so you can run a -separate example like:: - - ./bin/example_gemm_dl_fp16 1 1 1 - -and it should return something like:: - - a_m_k: dim 2, 
lengths {3840, 4096}, strides {1, 4096} - b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} - c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} - arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} - arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} - arg.c_grid_desc_m_n_{ 3840, 4096} - launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} - Warm up 1 time - Start running 10 times... - Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> - -.. note:: - - A new CMake flag ``DL_KERNELS`` has been added to the latest versions of CK. If you do not see the above results when running ``example_gemm_dl_fp16``, you might need to add ``-D DL_KERNELS=ON`` to your CMake command to build the operators supported on the `gfx1030` architecture. - -You can also run a separate test:: - - ctest -R test_gemm_fp16 - -If everything goes well you should see something like:: - - Start 121: test_gemm_fp16 - 1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec - - 100% tests passed, 0 tests failed out of 1 - -Summary -======= - -In this tutorial you took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. In the next tutorial you will run kernels with different configurations to find out the best one for your hardware and task. - -P.S.: If you are running on a cloud instance, don't forget to switch off the cloud instance. 
diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt old mode 100755 new mode 100644 index ee9f959d94..61f3ba5351 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -28,11 +28,23 @@ add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) + add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) + + +add_example_executable(example_gemm_xdl_fp16_fp8_streamk_v3 gemm_xdl_fp16_fp8_streamk_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_streamk_v3) + add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3) +set(GEMM_OPTIONS) +list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-16") +example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS}) + + list(APPEND gpu_list gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -103,3 +115,18 @@ add_example_executable(example_gemm_wmma_bf16 gemm_wmma_bf16.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16) add_example_executable(example_gemm_wmma_int8 gemm_wmma_int8.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_int8) + +add_example_executable(example_gemm_wmma_bf16_v3 gemm_wmma_bf16_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_v3) +add_example_executable(example_gemm_wmma_bf16_pk_i4_v3 gemm_wmma_bf16_pk_i4_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_bf16_pk_i4_v3) +add_example_executable(example_gemm_wmma_fp8_v3 
gemm_wmma_fp8_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp8_v3) +add_example_executable(example_gemm_wmma_fp16_v3 gemm_wmma_fp16_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_v3) +add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3) +add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) +add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 9073ffcfc1..434f549443 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -15,6 +15,8 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" @@ -57,8 +59,9 @@ struct ProblemSizeStreamK_universal final ck::index_t StrideB = -1; ck::index_t StrideC = -1; - ck::index_t Grid_size = -1; // defaults to max occupancy - ck::index_t Streamk_sel = 1; // defaults to 1-tile SK + ck::index_t Grid_size = -1; // defaults to max occupancy + ck::index_t Streamk_sel = 1; // defaults to 1-tile SK + ck::StreamKReductionStrategy reduction_strategy = ck::StreamKReductionStrategy::Atomic; }; struct ProblemSizeSplitK final @@ -128,11 +131,12 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 
9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl; + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl; return false; } @@ -172,7 +176,19 @@ bool parse_cmd_args(int argc, if(argc >= 11) { problem_size.Streamk_sel = std::stoi(argv[10]); - problem_size.Grid_size = std::stoi(argv[11]); + + if(argc >= 12) + { + problem_size.Grid_size = std::stoi(argv[11]); + + if(argc >= 13) + { + int reduction_strategy = std::stoi(argv[12]); + problem_size.reduction_strategy = reduction_strategy == 0 + ? ck::StreamKReductionStrategy::Atomic + : ck::StreamKReductionStrategy::Reduction; + } + } } } else @@ -181,9 +197,12 @@ bool parse_cmd_args(int argc, << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl << "arg10: stream-k select (-1: default config, 0: all DP, 1: 1-tile SK, 2: 2-tile SK)" - << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + << std::endl + << "arg11: Grid_size(-1 for max occupancy)" << std::endl + << "arg12: Reduction strategy (0: Atomic, 1: Reduction)" << std::endl; return false; } @@ -227,13 +246,14 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, 
StrideC" << std::endl - << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" - << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl + << "arg10: stream-k select (0: all DP, 1: 1-tile SK, 2: 2-tile SK)" + << "\narg11: Grid_size(-1 for max occupancy)" << std::endl; return false; } @@ -277,12 +297,13 @@ bool parse_cmd_args(int argc, } else { - std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl - << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl - << "arg10: KBatch" << std::endl; + std::cerr + << "arg1: verification (0=no, 1=CPU, 2=GPU, 3=CPU and GPU)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC (default: -1 or 0)" + << std::endl + << "arg10: KBatch" << std::endl; return false; } diff --git a/example/01_gemm/gemm_wmma_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_wmma_bf16_pk_i4_v3.cpp new file mode 100644 index 0000000000..69ced56c0b --- /dev/null +++ b/example/01_gemm/gemm_wmma_bf16_pk_i4_v3.cpp @@ -0,0 +1,253 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::pk_i4_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 32; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 128, 128, KPerBlock, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, + ADataType, ADataType, PermuteA, PermuteB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto 
f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * 
c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 
20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_bf16_v3.cpp b/example/01_gemm/gemm_wmma_bf16_v3.cpp new file mode 100644 index 0000000000..1dc5c5286f --- /dev/null +++ b/example/01_gemm/gemm_wmma_bf16_v3.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16_fp8_v3.cpp b/example/01_gemm/gemm_wmma_fp16_fp8_v3.cpp new file mode 100644 index 0000000000..359d823ac2 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_fp8_v3.cpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16_pk_i4_v3.cpp b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3.cpp new file mode 100644 index 0000000000..ec5e48a86a --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3.cpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; +static constexpr ck::index_t KPerBlock = 32; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 128, 128, KPerBlock, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 1, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, + ADataType, ADataType, PermuteA, PermuteB>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto 
f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * 
c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr 
<< gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp new file mode 100644 index 0000000000..d3ac184019 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 64; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_BScale_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 8, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + CDataType, CDataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto 
layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + 
b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 
01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + b1_scale_device_buf.ToDevice(b1_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + std::string device_name = ck::get_device_name(); + if(!(device_name.find("gfx11") != std::string::npos || + device_name.find("gfx12") != std::string::npos)) + { + std::cout << "This kernel support gfx1100 and gfx1200 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_dequant({K, N}); + + float v_b = 0; + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 
0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_k_n_dequant(k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_k_n(k / Scale_Block_K, n / Scale_Block_N)); + } + } + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp new file mode 100644 index 0000000000..7225dba721 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 128, + 128, 64, + 64, 8, 8, + 16, 16, + 4, 2, + S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, + S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, 1, + 1, 1, S<1, 32, 1, 4>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_wmma_fp8_v3.cpp b/example/01_gemm/gemm_wmma_fp8_v3.cpp new file mode 100644 index 0000000000..0376820b7b --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp8_v3.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ComputeTypeA = ck::f8_t; +using ComputeTypeB = ck::f8_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 128, + 128, 64, 64, + 8, 8, + 16, 16, + 4, 2, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 32, 1, 4>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, + ComputeTypeA, ComputeTypeB>; +// clang-format on + +using ReferenceComputeType = ck::f8_t; +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) +{ + if(!ck::is_gfx12_supported()) + { + std::cout << "This kernel support gfx12 only" << std::endl; + + return 0; + } + return !run_gemm_splitk_example(argc, argv); +} diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp old mode 100755 new mode 100644 diff --git a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp index 7c232f1bcf..7178ad46b9 100644 --- a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp @@ -133,7 +133,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " 
<< c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // weight permute @@ -192,14 +192,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp b/example/01_gemm/gemm_xdl_bf16_streamk_v3.cpp old mode 100755 new mode 100644 diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp new file mode 100644 index 0000000000..bd38eb17ee --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_fp8_streamk_v3.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2_Streamk_Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_Streamk_V3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 64, + 16, 16, + 256, 8, 16, + 16, 16, + 1, 1, + S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_streamk_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_universal_streamk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp index 61c5a32d5d..e16f184a20 100644 --- a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp @@ -134,7 +134,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem 
b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // weight permute @@ -242,14 +242,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp index 468dd699a1..f83d479713 100644 --- a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp @@ -161,7 +161,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); @@ -274,14 +274,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { 
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 5afb3d1554..b55627f3ee 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -31,15 +31,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl #else < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; #endif - // clang-format on +// clang-format on - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + // this instance has been tested working on gfx950 + // < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp index 80f7e95d30..266a1e9d3e 100644 --- 
a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp @@ -152,7 +152,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize() / + 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // do GEMM @@ -261,14 +262,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp index 7b72461dd9..0575314dff 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp @@ -132,7 +132,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * 
c_m_n_device_result.mDesc.GetElementSpaceSize()); // weight permute @@ -240,14 +240,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp b/example/01_gemm/gemm_xdl_fp8_streamk_v3.cpp old mode 100755 new mode 100644 diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp index 62037f7740..26ea31f20b 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include @@ -38,7 +38,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>; // clang-format on #else // clang-format off diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index 4a0c23cf44..d149fd88f1 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -56,10 +56,10 @@ using CDataType = float; using AccDataType = float; #endif - // clang-format on +// clang-format on - using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; template std::ostream& show_2d_matrix(std::ostream& os, Tensor& matrix) diff --git 
a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index c064ed500c..6c5d9f9fba 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -33,7 +33,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == -1 || stride == 0) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) diff --git a/example/01_gemm/run_gemm_example_streamk.inc b/example/01_gemm/run_gemm_example_streamk.inc index 438afcf71a..7e43847463 100644 --- a/example/01_gemm/run_gemm_example_streamk.inc +++ b/example/01_gemm/run_gemm_example_streamk.inc @@ -36,7 +36,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == -1 || stride == 0) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) diff --git a/example/01_gemm/run_gemm_example_streamk_v2.inc b/example/01_gemm/run_gemm_example_streamk_v2.inc index 9ee380d247..2700838bcc 100644 --- a/example/01_gemm/run_gemm_example_streamk_v2.inc +++ b/example/01_gemm/run_gemm_example_streamk_v2.inc @@ -21,6 +21,16 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto Grid_size = problem_size.Grid_size; auto Streamk_sel = problem_size.Streamk_sel; + auto reduction_strategy = problem_size.reduction_strategy; + if(reduction_strategy == ck::StreamKReductionStrategy::Atomic) + { + std::cout << "Using Atomic reduction strategy" << std::endl; + } + else + { + std::cout << "Using Parallel reduction strategy" << std::endl; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if constexpr(std::is_same_v) @@ -35,7 +45,7 
@@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == -1 || stride == 0) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) @@ -152,7 +162,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) Grid_size, a_element_op, b_element_op, - c_element_op); + c_element_op, + reduction_strategy); if(!gemm.IsSupportedArgument(argument)) { @@ -242,7 +253,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << gemm.GetTypeString() << std::endl; + << " GB/s, " << gemm.GetTypeString() + << (reduction_strategy == ck::StreamKReductionStrategy::Atomic ? " (Atomic)" + : " (Reduction)") + << std::endl; } return pass; } diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc index 2b60fa5d28..4adb6f896b 100644 --- a/example/01_gemm/run_gemm_example_v2.inc +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -34,7 +34,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { - if(stride == -1) + if(stride == -1 || stride == 0) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 18731e810e..03c531c1ad 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 87812369bd..5167097b6d 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index c3e6ef7d5d..abf7ef3905 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp index de7af85fb3..67b3e646f7 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "common.hpp" @@ -34,7 +34,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm::value, int8_t, InOutDataType>::type; #else - using InOutDataTypeInDevice = InOutDataType; + using InOutDataTypeInDevice = InOutDataType; #endif using DeviceReduceInstance = diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index db162fe444..63a2aea0b3 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -141,8 +141,8 @@ bool 
run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co a_tensors_device.reserve(group_count); b_tensors_device.reserve(group_count); - d_tensors_device.reserve(group_count); c_tensors_device.reserve(group_count); + d_tensors_device.resize(group_count); // reserve and update vector size std::size_t flop = 0, num_btype = 0; diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 86b3182a52..7186c22233 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -21,6 +21,7 @@ struct ExecutionConfig final bool do_verification = true; int init_method = 1; bool time_kernel = false; + bool async_hargs = false; }; bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) @@ -190,10 +191,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm_workspace.Realloc(workspace_size); gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer()); } - if(hargs_size > 0) + if(config.async_hargs && hargs_size > 0) { hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size)); - gemm.SetHostKernelArgs(&argument, gemm_hargs); + gemm.SetHostKernelArgsPointer(&argument, gemm_hargs); } if(!gemm.IsSupportedArgument(argument)) @@ -203,16 +204,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - hipStream_t stream0 = nullptr; - hip_check_error(hipStreamCreate(&stream0)); + if(!config.async_hargs) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + } + else + { + hipStream_t stream0 = nullptr; + hip_check_error(hipStreamCreate(&stream0)); - hipEvent_t event0 = nullptr; - hip_check_error(hipEventCreate(&event0)); + hipEvent_t event0 = nullptr; + hip_check_error(hipEventCreate(&event0)); - invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); + invoker.Run(argument, 
StreamConfig{nullptr, false}, stream0, event0); - hip_check_error(hipEventSynchronize(event0)); - hip_check_error(hipStreamSynchronize(stream0)); + hip_check_error(hipEventSynchronize(event0)); + hip_check_error(hipStreamSynchronize(stream0)); + } bool pass = true; if(config.do_verification) @@ -280,18 +288,25 @@ bool run_grouped_gemm_example(int argc, char* argv[]) problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 4) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.async_hargs = std::stoi(argv[4]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: async hargs (0=no, 1=yes)\n"); exit(0); } diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp index 1bea1bcf3e..3e3c586dba 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -175,15 +175,15 @@ auto run_gemm_reduce_max_xdl(ck::index_t M, auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), - {}, + {}, e_device_buf.GetDeviceBuffer(), - {r0_device_buf.GetDeviceBuffer()}, + {r0_device_buf.GetDeviceBuffer()}, M, N, K, StrideA, StrideB, - {}, + {}, StrideE, a_element_op, b_element_op, diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 62295c57eb..42bfea372e 100644 --- 
a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -207,7 +207,7 @@ int main(int argc, char* argv[]) auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), nullptr, - {}, + {}, c_device_buf.GetDeviceBuffer(), p_reduces, M, @@ -216,9 +216,9 @@ int main(int argc, char* argv[]) StrideA, StrideB, StrideC, - {}, + {}, gemm_element_ops, - {}, + {}, reduce_in_element_ops, reduce_out_element_ops, BatchCount); diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp index f0160b31ce..84f92eba8e 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include #include @@ -71,9 +71,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 256, // BlockSize 256, // MPerBlock 128, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 + 64, // KPerBlock + 16, // AK1 + 16, // BK1 32, // MPerXDL 32, // NPerXDL 4, // MXdlPerWave @@ -84,14 +84,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 2, // ABlockTransferSrcVectorDim 8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM + 0, // ABlockLdsExtraM S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN + 0, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 8c4913dbcc..3582bc5e33 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -212,7 +212,8 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize() / + 2); DeviceMem b1_g_scale_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize()); DeviceMem 
c_g_m_n_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize()); diff --git a/example/27_layernorm2d_fwd/run_layernorm_example.inc b/example/27_layernorm2d_fwd/run_layernorm_example.inc index 23608a1eea..02b60fe548 100644 --- a/example/27_layernorm2d_fwd/run_layernorm_example.inc +++ b/example/27_layernorm2d_fwd/run_layernorm_example.inc @@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example() {0, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index cdfd86dff4..c693995140 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -126,10 +126,10 @@ int run(int argc, char* argv[]) if(i < 4) { - std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " - << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", " - << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", " - << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; + std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " << "b0_gs_ns_ks[" + << i << "]: " << b0_gs_ns_ks.mDesc << ", " << "b1_gs_os_ns[" << i + << "]: " << b1_gs_os_ns.mDesc << ", " << "c_gs_ms_os[" << i + << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; } switch(init_method) diff --git a/example/34_batchnorm/batchnorm_backward_nhwc.cpp b/example/34_batchnorm/batchnorm_backward_nhwc.cpp index 
3756310fd7..9737b0d99b 100644 --- a/example/34_batchnorm/batchnorm_backward_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_backward_nhwc.cpp @@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification, return (pass); }; -static const double epsilon = std::numeric_limits::epsilon(); - int main(int argc, char* argv[]) { + static const double epsilon = std::numeric_limits::epsilon(); + bool pass = true; if(argc > 1) diff --git a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp index 6a8002025a..1ffbabd04b 100644 --- a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp @@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification, return (pass); }; -static const double epsilon = std::numeric_limits::epsilon(); - int main(int argc, char* argv[]) { - bool pass = true; + static const double epsilon = std::numeric_limits::epsilon(); + bool pass = true; if(argc > 1) { diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp index b27358fd9d..06441be860 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp @@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, return (pass); }; -const double epsilon = std::numeric_limits::epsilon(); -static const double averageFactor = 0.1; - int main(int argc, char* argv[]) { - bool pass = true; + const double epsilon = std::numeric_limits::epsilon(); + static const double averageFactor = 0.1; + bool pass = true; if(argc > 1) { diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp index ffb9f4b584..8f2b7613b5 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp +++ 
b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp @@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, return (pass); }; -const double epsilon = std::numeric_limits::epsilon(); -static const double averageFactor = 0.1; - int main(int argc, char* argv[]) { - bool pass = true; + const double epsilon = std::numeric_limits::epsilon(); + static const double averageFactor = 0.1; + bool pass = true; if(argc > 1) { diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp index 97a3f89e5e..fc55019fc4 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 
2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on #else diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index d2337dcda5..26a03f289d 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -129,11 +129,11 @@ int main() auto argument_ptr = device_instance.MakeArgumentPointer( out_dev.GetDeviceBuffer(), {ck::type_convert(emb_a_dev.GetDeviceBuffer()), - ck::type_convert(emb_b_dev.GetDeviceBuffer()), - ck::type_convert(emb_c_dev.GetDeviceBuffer())}, + ck::type_convert(emb_b_dev.GetDeviceBuffer()), + ck::type_convert(emb_c_dev.GetDeviceBuffer())}, {ck::type_convert(index_a_dev.GetDeviceBuffer()), - ck::type_convert(index_b_dev.GetDeviceBuffer()), - ck::type_convert(index_c_dev.GetDeviceBuffer())}, + ck::type_convert(index_b_dev.GetDeviceBuffer()), + ck::type_convert(index_c_dev.GetDeviceBuffer())}, gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), current_dim, diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index 6af8ac6488..1823d4fc0a 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -92,7 +92,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_params = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/39_permute/common.hpp 
b/example/39_permute/common.hpp index 54f3a78809..b23128a536 100644 --- a/example/39_permute/common.hpp +++ b/example/39_permute/common.hpp @@ -249,8 +249,8 @@ inline auto to_array(Range& range) noexcept } template -inline auto is_valid_axes(const Axes& axes) - -> std::enable_if_t, bool> +inline auto +is_valid_axes(const Axes& axes) -> std::enable_if_t, bool> { using std::empty; if(empty(axes)) @@ -357,10 +357,11 @@ auto extend_axes(const Problem::Axes& axes) } template -auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t< - detail::is_bidirectional_range_v && detail::is_sized_range_v && - detail::is_bidirectional_range_v && detail::is_sized_range_v, - bool> +auto advance_indices(const Shape& shape, Indices& indices) + -> std::enable_if_t< + detail::is_bidirectional_range_v && detail::is_sized_range_v && + detail::is_bidirectional_range_v && detail::is_sized_range_v, + bool> { using std::size; if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices))) diff --git a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc index 853ff791a6..ab6f317bc6 100644 --- a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) {0, 0, 0, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 4}, // reduction dimension: [H, W, C] 1e-6, x_dev.GetDeviceBuffer(), diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 
9431a8cde4..c40447e1f9 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -152,7 +152,7 @@ int main(int argc, char* argv[]) std::array inputs = {input_dev_buf.GetDeviceBuffer()}; std::array outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(), - output_scaled_casted_dev_buf.GetDeviceBuffer()}; + output_scaled_casted_dev_buf.GetDeviceBuffer()}; std::cout << "Input: " << input.mDesc << std::endl; std::cout << "Scale: " << scale << std::endl; @@ -164,8 +164,8 @@ int main(int argc, char* argv[]) auto launch_transpose_scale = [&]() { auto transposeScale = DeviceElementwisePermuteInstance{}; auto argument = transposeScale.MakeArgumentPointer(dims, - {in_strides}, - {out_strides, in_strides}, + {in_strides}, + {out_strides, in_strides}, inputs, outputs, ScalePassThrough{scale}); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp index 93034a8b70..2582ea8a11 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 8b88e2482d..57e2feb084 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -213,7 +213,7 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b_device_buf.GetDeviceBuffer()}, std::array{d_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp index eaabccdf2a..ec1b2d6018 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -194,9 +194,9 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b0_device_buf.GetDeviceBuffer(), - b1_device_buf.GetDeviceBuffer()}, + b1_device_buf.GetDeviceBuffer()}, std::array{}, e_device_buf.GetDeviceBuffer(), std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp index 6940c20695..f521c51d67 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp @@ -265,10 +265,10 @@ bool run_grouped_conv_fwd(bool do_verification, auto device_ew_scale = DeviceElementwiseScale{}; auto scale_invoker = device_ew_scale.MakeInvoker(); auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths, - {e_g_n_k_wos_strides}, - {e_g_n_k_wos_strides}, - {conv_device_buf.GetDeviceBuffer()}, - {out_device_buf.GetDeviceBuffer()}, + {e_g_n_k_wos_strides}, + {e_g_n_k_wos_strides}, + 
{conv_device_buf.GetDeviceBuffer()}, + {out_device_buf.GetDeviceBuffer()}, scale_convert); if(!device_ew_scale.IsSupportedArgument(scale_argument)) diff --git a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc index 1a0b558e2c..f75c01ec61 100644 --- a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc +++ b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc @@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example() {0, W * C, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 3}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 38b42fefc4..d1e1a51afd 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1,17 +1,72 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp16_bpreshuffle gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) -# 
add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) +set(EXAMPLE_COMPILE_OPTIONS) +# Open it when SGBPack branch landed on mainline +# list(APPEND EXAMPLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm -misched=gcn-iterative-max-occupancy-experimental") +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) +add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) +add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp) +add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp) -list(APPEND gpu_list gfx942) +list(APPEND gpu_list gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - # add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) + add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) + if(hip_VERSION_FLAT LESS_EQUAL 600342132) + set(EXAMPLE_COMPILE_OPTIONS) + check_cxx_compiler_flag("-mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1" HAS_MAX_ILP_SCHEDULING_STRATEGY) + if(HAS_MAX_ILP_SCHEDULING_STRATEGY) + list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) + endif() + example_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + example_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + endif() + set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") + 
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS}) + example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) + example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) set(target 1) endif() endforeach() + +set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") +set(BLOCKSCALE_GEMM_OPTIONS ) +check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) +check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) + +if(hip_VERSION_FLAT LESS 600443483 OR hip_VERSION_FLAT GREATER_EQUAL 700000000) + if(HAS_MISCHED_BOTTOMUP) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1") + elseif(HAS_MISCHED_PRERA_DIRECTION) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup") + endif() +else() + if(HAS_MISCHED_BOTTOMUP) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --misched-bottomup=1") + elseif(HAS_MISCHED_PRERA_DIRECTION) + list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --misched-prera-direction=bottomup") + endif() +endif() + +check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental " HAS_MAX_OCCUPANCY_EXPERIMENTAL) +if(HAS_MAX_OCCUPANCY_EXPERIMENTAL) + list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental) +endif() +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_ab_scale PRIVATE 
${BLOCKSCALE_GEMM_OPTIONS}) +example_compile_options(example_gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) + +example_compile_options(example_moe_gemm2_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) +example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCALE_GEMM_OPTIONS}) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp new file mode 100644 index 0000000000..69803c7eeb --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; 
+using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(F16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()(BF16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +void preShuffleBuffer(const F16* src, F16* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(F16); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + 
} + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + // kernel 1: 256->32x128x128 + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, + 32, 128, 128, + 8, 8, + 32, 32, + 1, 1, 
+ S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F16>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, 
B0Layout{})); // use laout only for size + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_m_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + 
e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, false, 1}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index b54ba5ddfb..5aa978fbf0 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_ A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - 16, 128, - 256, 16, 16, + 128, 128, + 128, 16, 16, 16, 16, - 1, 2, - S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 2, S<1, 16, 1, 16>, S<8>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 2, S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp new file mode 100644 index 0000000000..d64266bccf --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_blockscale_bpreshuffle.cpp @@ -0,0 +1,372 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_blockscale_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using A1Layout = Col; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +void preShuffleBuffer(const FP8* src, FP8* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + 
dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle + // clang-format off + , S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 1, S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + bool flush_cache = true; + + // GEMM shape + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 8) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + flush_cache = std::stoi(argv[7]); + + StrideA = K; + StrideB = K; + StrideE = N; + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + exit(0); + } + + // Transpose the AScale tensor for better 
performance + ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AK, + A1Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + 
b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); + + 
b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float ave_time = 0.0f; + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + Tensor c_m_n({M, N}); + Tensor a_m_k({M, K}); + Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + 
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + +#if 1 + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } +#endif + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index e4e6a4f1a7..fe1eca51b0 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -9,7 +9,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" @@ -140,14 +139,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, - 128, 128, 128, + 256, 256, 128, 16, 16, - 32, 32, - 2, 2, + 16, 16, + 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; + 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 66825edcf9..9fe9fdde78 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -25,7 +25,6 @@ template using S = ck::Sequence; using F16 = ck::half_t; -// using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using F32 = float; @@ -36,17 +35,19 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = EDataType; using D0DataType = F32; using D1DataType = F32; -using DsDataType = ck::Tuple; +using D2DataType = F32; +using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -59,35 +60,66 @@ struct MulABScale __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { - e = ck::type_convert(c * d1 * d0); + (void)d0; + (void)d1; + e = ck::type_convert(c); } -}; - -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu -{ - template - __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const { - // act - float x0 = 0; - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); - e = ck::type_convert(x0); + (void)d0; + (void)d1; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const EDataType& d0, const EDataType& d1) const + { + (void)d0; + (void)d1; + e = ck::type_convert(c); } }; -// using DsLayout = DsLayoutGate; -// using DsDataType = DsDataTypeGate; -using CDEElementOp = MulABScale; +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + 
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; -// using CDEElementOp = MulABScaleSiluMulGate; +using CDEElementOp = MulABScaleExpertWeight; void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) { @@ -126,20 +158,22 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 2; -static constexpr ck::index_t NXDLPerWave = 2; -static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; -static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = true; -static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); -static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); -static constexpr ck::index_t EVec = 16 / sizeof(EDataType); -static constexpr ck::index_t D0Vec = 1; -static constexpr ck::index_t D1Vec = 1; -// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm +static constexpr 
ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); +static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4); + +static constexpr ck::index_t BLOCKSIZE = 256; +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = false; +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -150,15 +184,15 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // mn_perxdl MNPerXDL, MNPerXDL, // mn_xdlperwave - MXDLPerWave, NXDLPerWave, + MXDLPerWave, NXDLPerWave, // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + 2, 2, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; // clang-format on @@ -170,15 +204,13 @@ int main(int argc, char* argv[]) 
// GEMM shape ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t K = 6144; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 8; - ck::index_t valid_tile_num = 8; - ck::index_t tokens = 128; + ck::index_t sorted_tile_num = 256; + ck::index_t valid_tile_num = 256; + ck::index_t tokens = 16384; ck::index_t topk = 2; - // ck::index_t tokens = batch * topk; - if(argc == 1) { // use default case @@ -224,28 +256,23 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{1, 0}; + constexpr auto StrideDs = std::array{1, 1, 1}; ck::index_t KBatch = 1; - // const ck::index_t experts = 8; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); - // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + max_token_id.mData = {valid_size}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = eids[i]; + expert_ids.mData[i] = i / (valid_tile_num / experts); } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; - // sorted_token_ids.mData[0] = 0; + for(int i = 0; i < sorted_size; i++) { int tile_off = i % MPerBlock; @@ -259,48 +286,54 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = 
tokens; } } - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; switch(init_method) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: - 
a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0, 1}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; case 3: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); @@ -310,16 +343,16 @@ int main(int argc, char* argv[]) DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - // a0_t_k.savetxt("a.txt"); - // d0_t_n.savetxt("d0_t_n.txt", "int"); - // d1_e_n.savetxt("d1_e_n.txt", "int"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); 
expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); a0_device_buf.ToDevice(a0_t_k.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -330,7 +363,8 @@ int main(int argc, char* argv[]) int NPerXdl = device_op.GetPreShuffleParameters(); - preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); b0_device_buf.ToDevice(b0_preshuffled.mData.data()); @@ -342,7 +376,8 @@ int main(int argc, char* argv[]) a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer()}, + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), tokens, topk, @@ -368,9 +403,9 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) * K * N * experts + + sizeof(B0DataType) * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -392,10 +427,13 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + PassThrough, + ActOP, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -404,8 +442,11 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, + d2_e_n, PassThrough{}, PassThrough{}, PassThrough{}); @@ -428,15 +469,15 @@ 
int main(int argc, char* argv[]) cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), - d1_e_n(e, n)); + d1_e_n(e, n), + d2_e_n(e, n)); } } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - // e_t_n_device_result.savetxt("out.txt"); - // e_t_n_host_result.savetxt("ref.txt"); + return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) ? 0 : 1; } diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp new file mode 100644 index 0000000000..c5328226ff --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using A1DataType = F32; +using 
B0DataType = F8; +using B1DataType = F32; +// using EDataType = F16; +using EDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = EDataType; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const float& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const EDataType& c, const float& d2) const + { + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d2); + } +}; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * static_cast(K) + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = 
MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = true; + +#if 0 +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); +static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4); +static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave; +static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave; +static constexpr ck::index_t BLOCKSIZE = 256; + +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale + // clang-format off + < Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, NPerBlock, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + MNPerXDL, MNPerXDL, + // mn_xdlperwave + MXDLPerWave, NXDLPerWave, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + // 
CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; +#else +static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, 128, 128, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; +#endif +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +#if 1 + // GEMM shape + ck::index_t N = 4096; + ck::index_t K = 6144; + ck::index_t experts = 8; + ck::index_t topk = 2; + // ck::index_t sorted_tile_num = 515; + // ck::index_t valid_tile_num = 512; + // ck::index_t tokens = 8192; + // ck::index_t sorted_tile_num = 15; + // ck::index_t valid_tile_num = 13; + ck::index_t sorted_tile_num = 259; + ck::index_t valid_tile_num = 256; + ck::index_t tokens = 4096; +#else + // deepseek + ck::index_t N = 2048; + ck::index_t K = 7168; + ck::index_t experts = 256; + ck::index_t topk = 8; + ck::index_t tokens = 4096; + ck::index_t sorted_tile_num = 261; + ck::index_t valid_tile_num = 256; +#endif + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case 
+ do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0}; + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); + max_token_id.mData = {valid_size}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + + int 
token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, + (K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N * 2}, + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + 
a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem 
max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + a1_device_buf.ToDevice(a1_t_k.mData.data()); + b1_device_buf.ToDevice(b1_e_n_k.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw 
std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; + std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + + sizeof(B0DataType) * K * N * 2 * experts + + sizeof(EDataType) * valid_tile_num * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s.\n" + << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor a_t_k({tokens, K}); + Tensor b_e_n_k({experts, K, N * 2}); + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + // handle scale before ref. 
+ for(int t = 0; t < tokens; ++t) + { + for(int k = 0; k < K; ++k) + { + a_t_k(t, k) = ck::type_convert(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K); + } + } + + for(int e = 0; e < experts; ++e) + { + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N * 2; ++n) + { + b_e_n_k(e, k, n) = ck::type_convert(b0_e_n_k(e, k, n)) * + b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N); + } + } + } + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm1BlockScale; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a_t_k, + b_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 17f4cd8a3f..f78e6e48a5 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include #include @@ -36,17 +36,19 @@ using A0DataType = F8; using B0DataType = I4; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; -using DsDataType = ck::Tuple; +using D2DataType = F32; +using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -55,43 +57,74 @@ struct MulABScale __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const + { + (void)d0; + (void)d1; +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c); +#else + e = ck::type_convert(c); +#endif + } template <> __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { + (void)d0; + (void)d1; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d0 * 16); + e = ck::type_convert(c); #else - e = ck::type_convert(c * d1 * d0); + e = ck::type_convert(c); #endif } }; -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu +struct MulABScaleExpertWeight { - template + template __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { - // act - float x0 = 0; -#if CK_USE_PK4_LAYOUT_SHUFFLE - 
ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16); -#else - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); -#endif - e = ck::type_convert(x0); + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } }; -using CDEElementOp = MulABScale; +static constexpr bool MulRoutedWeight = true; + +using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true + +// using CDEElementOp = MulABScale; // combine MulRoutedWeight = true #if 1 void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) @@ -132,53 +165,24 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -#if 0 -static constexpr ck::index_t MPerBlock = 64; -static constexpr ck::index_t MXDLPerWave = 1; -static constexpr ck::index_t NXDLPerWave = 2; -static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; -static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = false; -static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); -static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); -static constexpr ck::index_t EVec = 16 / sizeof(EDataType); -static constexpr ck::index_t D0Vec = 1; -static constexpr ck::index_t D1Vec = 1; -// clang-format off -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< - Row, Col, 
DsLayout, ELayout, - A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, - BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, - AK1, BK1, - MNPerXDL, MNPerXDL, - MXDLPerWave, NXDLPerWave, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - MXDLPerWave, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; -// clang-format on -#else static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - 256, MPerBlock, 128, 128, + 256, MPerBlock, 64, 128, 16, 32, - 32, 32, - 4, 1, + 16, 16, + 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + 2, 1, S<1, 32, 1, 8>, S<8, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, true, ck::index_t, A0DataType>; // clang-format on -#endif int main(int argc, char* argv[]) { @@ -186,19 +190,16 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape - ck::index_t N = 14336 * 2; + ck::index_t N = 14336; ck::index_t K = 4096; ck::index_t experts = 8; ck::index_t sorted_tile_num = 16; ck::index_t valid_tile_num = 13; ck::index_t sorted_size = sorted_tile_num * MPerBlock; 
ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t tokens = 64; + ck::index_t tokens = 644; ck::index_t topk = 2; if(argc == 1) @@ -232,20 +233,20 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0}; + constexpr auto StrideDs = std::array{1, 1, 1}; ck::index_t KBatch = 1; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); - max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 0, 0, 0}; + max_token_id.mData = {valid_size}; int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; } - int token_per_tile = tokens * topk / valid_tile_num; + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; for(int i = 0; i < sorted_size; i++) { @@ -260,17 +261,21 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, 
{topk * N, N, 1})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; @@ -279,31 +284,35 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); - DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem 
b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); @@ -312,6 +321,7 @@ int main(int argc, char* argv[]) a0_device_buf.ToDevice(a0_t_k.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -347,7 +357,7 @@ int main(int argc, char* argv[]) int n1 = n % NLane; int k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); + tempk = k % (KLane * KPack); int k1 = tempk / KPack; int k2 = tempk % KPack; @@ -424,7 +434,8 @@ int main(int argc, char* argv[]) a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer()}, + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), tokens, topk, @@ -440,20 +451,25 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! 
device_gemm with the specified compilation parameters does " "not support this GEMM problem"); } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + if(time_kernel) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) / 2 * K * N * experts + + sizeof(B0DataType) / 2 * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -475,10 +491,13 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + PassThrough, + Act_OP, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -487,8 +506,11 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, + d2_e_n, PassThrough{}, PassThrough{}, PassThrough{}); @@ -511,13 +533,14 @@ int main(int argc, char* argv[]) cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), - d1_e_n(e, n)); + d1_e_n(e, n), + d2_e_n(e, n)); } } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) ? 
0 : 1; } diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 0d12441016..6a3986ea32 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -25,7 +25,6 @@ template using S = ck::Sequence; using F16 = ck::half_t; -// using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using F32 = float; @@ -36,7 +35,7 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -48,7 +47,6 @@ using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using D2Layout = ELayout; -// using DsLayoutGate = ck::Tuple; using DsLayout = ck::Tuple; // d0: ascale, d1: bscale, d2:expert weight @@ -62,11 +60,19 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { - // for real kernel use - // warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. 
- // tofix:felix (void)d0; - e = ck::type_convert(c * d1 * d2); + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } // for reference cpu template <> @@ -117,16 +123,14 @@ using BElementOp = PassThrough; using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 256; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 2; -static constexpr ck::index_t NXDLPerWave = 2; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MXDLPerWave = 16; +static constexpr ck::index_t NXDLPerWave = 4; +static constexpr ck::index_t NPerBlock = 256; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -// static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint -// static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); @@ -135,6 +139,8 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool PerTokenQuant = true; +static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| 
MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -159,12 +165,12 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, // S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, PerTokenQuant, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -176,26 +182,23 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape ck::index_t N = 4096; ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 6; - ck::index_t valid_tile_num = 6; + ck::index_t 
sorted_tile_num = 133; + ck::index_t valid_tile_num = 128; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t tokens = 128; + ck::index_t tokens = 16384; ck::index_t topk = 2; if(argc == 1) { // use default case } - else if(argc == 3) + else if(argc == 4) { // use default case do_verification = std::stoi(argv[1]); @@ -211,6 +214,18 @@ int main(int argc, char* argv[]) K = std::stoi(argv[5]); tokens = std::stoi(argv[6]); } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); @@ -224,11 +239,11 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0, 0}; + constexpr auto StrideDs = PerTokenQuant ? 
std::array{1, 1, 0} + : std::array{0, 0, 0}; ck::index_t KBatch = 1; - // const ck::index_t experts = 8; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1})); @@ -236,10 +251,10 @@ int main(int argc, char* argv[]) // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} for(int i = 0; i < sorted_tile_num; i++) { - expert_ids.mData[i] = eids[i]; + expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts); } if(tokens * topk > valid_size) { @@ -248,7 +263,7 @@ int main(int argc, char* argv[]) } int token_per_tile = tokens * topk / valid_tile_num; int tokenid = 0; - // sorted_token_ids.mData[0] = 0; + for(int i = 0; i < sorted_size; i++) { int tile_off = i % MPerBlock; @@ -262,13 +277,14 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - expert_ids.savetxt("expert_ids.txt", "int"); - sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d0_t_n( + HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N}, {PerTokenQuant ? 
StrideDs[1] * N : 1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); @@ -314,12 +330,7 @@ int main(int argc, char* argv[]) DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - // a0_t_k_k.savetxt("a.txt"); - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); - // d0_t_n.savetxt("d0_t_n.txt", "int"); - // d1_e_n.savetxt("d1_e_n.txt", "int"); - // d2_e_n.savetxt("d2_e_n.txt", "int"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -397,7 +408,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_n({tokens, N}); + Tensor c_t_n({tokens, N}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2; + CDEElementOp, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, @@ -437,8 +449,7 @@ int main(int argc, char* argv[]) } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - // e_t_n_device_result.savetxt("out.txt"); - // e_t_n_host_result.savetxt("ref.txt"); + return ck::utils::check_err( e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) ? 
0 diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp new file mode 100644 index 0000000000..354957c0d1 --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -0,0 +1,541 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm2_blockscale.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +using EDataType = F16; +// using EDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = EDataType; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +// using DsLayoutGate = ck::Tuple; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ 
__device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; + // for real kernel use + + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const EDataType& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const float& c, const float& d2) const + { + // for real kernel use + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d2); + } +}; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * static_cast(K) + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; +static constexpr bool MulRoutedWeight = true; + +#if 0 +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t BLOCKSIZE = 256; +static constexpr ck::index_t 
MXDLPerWave = 2; +static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); + +static constexpr ck::index_t CShuffleNLane = 16; +static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 2; +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; +static constexpr ck::index_t D2Vec = 1; + +// clang-format off + +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, NPerBlock, KPerBlock, + AK1, BK1, + MNPerXDL, MNPerXDL, + MXDLPerWave, NXDLPerWave, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; + +#else +static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, 128, 128, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, 
ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; +#endif +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // tokens = 1 + // topk = 1 + // experts = 8 + // per expert: + + constexpr ck::index_t valid_tile_num = + 26; // 13 for 128; 52 for 32; 4096 for ds // > token * topk / MPerBlock + constexpr ck::index_t sorted_tile_num = valid_tile_num + 3; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; +#if 1 + // GEMM shape + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; +#else + // deepseek + ck::index_t N = 2048; + ck::index_t K = 7160; + ck::index_t experts = 256; + ck::index_t tokens = 1; + ck::index_t topk = 8; +#endif + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0}; + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N; + + ck::index_t 
KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + + max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + // int eids[] = {0, 1, 3, 3, 3}; + // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7}; //, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + // int eids[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + // 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + // 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + // 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, + // 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, + // 7, 7, + // 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + Scale_Block_K - 1) / Scale_Block_K}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k(HostTensorDescriptor( + {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N}, + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, 
{N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + 
d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{1.0, 1.0}); + for(auto i = 0; i < N * K; i++) + { + b0_e_n_k.mData[i] = ck::type_convert(static_cast(0.1)); + b0_e_n_k.mData[i + N * K] = ck::type_convert(static_cast(0.2)); + } + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + a1_device_buf.ToDevice(a1_t_k_k.mData.data()); + 
b1_device_buf.ToDevice(b1_e_n_k.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t num_btype = sizeof(A0DataType) * tokens * K * topk + + sizeof(B0DataType) * K * N * experts + + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s.\n" + << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor a_t_k_k({tokens, topk, K}); + Tensor b_e_n_k({experts, K, N}); + Tensor c_t_n({tokens, N}); + + for(int t = 0; t < tokens; ++t) + { + for(int tk = 0; tk < topk; ++tk) + { + for(int k = 0; k < K; ++k) + { + a_t_k_k(t, tk, k) = ck::type_convert(a0_t_k_k(t, tk, k)) * + a1_t_k_k(t, tk, k / Scale_Block_K); + } + } + } + + for(int e = 0; e < experts; ++e) + { + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + b_e_n_k(e, k, n) = ck::type_convert(b0_e_n_k(e, k, n)) * + b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N); + } + } + } + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm2BlockScale; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a_t_k_k, + b_e_n_k, + d2_e_n, + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + + for(int n = 0; n < N; ++n) + { + 
e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 8441862004..3745e3d0af 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -62,11 +62,13 @@ struct MulABScaleExpertWeight EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { (void)d0; + (void)d1; + (void)d2; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d2 * 16); + e = ck::type_convert(c * 16); #else - e = ck::type_convert(c * d1 * d2); + e = ck::type_convert(c); #endif } // for reference cpu @@ -125,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 4; -static constexpr ck::index_t NXDLPerWave = 1; +static constexpr ck::index_t MXDLPerWave = 8; +static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE 
/ CShuffleNLane; @@ -138,6 +140,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, @@ -148,8 +151,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic MXDLPerWave, NXDLPerWave, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - 1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -158,9 +161,6 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape ck::index_t N = 4096; @@ -281,7 +281,7 @@ int main(int argc, char* argv[]) break; case 4: a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); @@ -298,7 +298,7 @@ int main(int argc, char* argv[]) DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize()); - DeviceMem b0_device_buf(sizeof(B0DataType) 
* b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); @@ -407,13 +407,18 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " "not support this GEMM problem"); } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + if(time_kernel) { // not result correct here because output buf not setzero @@ -450,7 +455,8 @@ int main(int argc, char* argv[]) AccDataType, PassThrough, PassThrough, - CDEElementOp>; + CDEElementOp, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); diff --git a/example/66_complex_contraction_bilinear/CMakeLists.txt b/example/66_complex_contraction_bilinear/CMakeLists.txt old mode 100755 new mode 100644 diff --git a/example/66_complex_contraction_bilinear/README.md b/example/66_complex_contraction_bilinear/README.md old mode 100755 new mode 100644 diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp32.cpp old mode 100755 new mode 100644 diff --git a/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp b/example/66_complex_contraction_bilinear/complex_contraction_bilinear_xdl_fp64.cpp old mode 100755 new mode 100644 diff --git 
a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 9e95c3e007..6ee43aac62 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -1,10 +1,68 @@ add_custom_target(example_gemm_mx) -add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale) +add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) -add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale) +add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) -add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale) +# TODO: Fix RRR +# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp) +# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) + +add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp6) + +add_example_executable(example_gemm_mx_bf6 gemm_mx_bf6.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_bf6) + +add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp4) + +add_example_executable(example_gemm_mx_fp4_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp4_bpreshuffle) + +add_example_executable(example_moe_gemm1_xdl_mx_fp4_bns moe_gemm1_xdl_mx_fp4_bns.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns) + +add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp) +add_example_dependencies(example_gemm_mx 
example_moe_gemm2_xdl_mx_fp4_bns) + +add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4) + +add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4) + +add_example_executable(example_moe_gemm1_xdl_mx_fp4_bpreshuffle moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bpreshuffle) + +add_example_executable(example_moe_gemm2_xdl_mx_fp4_bpreshuffle moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp) +add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bpreshuffle) + +set(FP4_MXGEMM_OPTIONS) +list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1") +example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) + +# mx moe B no-shuffling + scale shuffling +example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS}) + +# mx moe B no-shuffling + scale shuffling (async loads) +example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS}) + +# mx moe B shuffling + scale shuffling (async loads) +example_compile_options(example_moe_gemm1_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) +example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS}) + +set(FP8_MXGEMM_OPTIONS) +list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1") +example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS}) + 
+set(FP6_MXGEMM_OPTIONS) +list(APPEND FP6_MXGEMM_OPTIONS -mavx512f) +example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS}) +example_compile_options(example_gemm_mx_bf6 PRIVATE ${FP6_MXGEMM_OPTIONS}) diff --git a/example/67_gemm_microscaling/README.md b/example/67_gemm_microscaling/README.md index 713902588d..007c934b7e 100644 --- a/example/67_gemm_microscaling/README.md +++ b/example/67_gemm_microscaling/README.md @@ -8,18 +8,20 @@ Custom verification parameters: # arg2: initialization (0=constant values, 1=integer values, 2=decimal values) # arg3: time kernel (0=no, 1=yes) # arg4: verbosity (0=no info, 1=verbose info) -# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC +# arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC # arg11: KBatch -./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1 +# arg12: warmup runs pre-timing +# arg13: repeat run count for timing +./bin/example_gemm_mx_fp8 1 1 0 1 ``` Custom tensor shapes: ```bash -./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1 +./bin/example_gemm_mx_fp8 1 2 1 0 256 256 512 -1 -1 -1 1 10 10 ``` Default invocation: ```bash -# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0 -./bin/example_gemm_mx_fp8_fp8_scale +# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0 +./bin/example_gemm_mx_fp8 ``` \ No newline at end of file diff --git a/example/67_gemm_microscaling/gemm_mx_bf6.cpp b/example/67_gemm_microscaling/gemm_mx_bf6.cpp new file mode 100644 index 0000000000..34810c2961 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_bf6.cpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_mx_common.hpp" + +using ADataType = ck::bf6x16_pk_t; +using BDataType = ck::bf6x16_pk_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t DataPackedSize = 16; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 bf6 = 16 bf6x16_pk_t + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XPackedDataType, // AScaleDataType + BDataType, // BDataType + XPackedDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 1, // AK1 + 1, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // 
ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_bf8.cpp new file mode 100644 index 0000000000..58f2dcb010 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_bf8.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_mx_common.hpp" + +using ADataType = ck::bf8_t; +using BDataType = ck::bf8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256; + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 128, // BlockSize: Thread block size + 128, // MPerBlock + 32, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<16, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 8, 1>, // 
BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 9a05954c73..2d0585c880 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -23,8 +23,9 @@ template using S = ck::Sequence; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using MFMA = ck::tensor_layout::gemm::MFMA; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -36,6 +37,8 @@ struct ExecutionConfig final int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values) bool time_kernel = false; // (0=no, 1=yes) int verbosity = 0; // (0=no info, 1=verbose info) + int warm_up = 10; + int repeat = 10; }; struct ProblemSizeSplitK final @@ -86,6 +89,8 @@ bool parse_cmd_args(int argc, if(argc >= 12) { problem_size.KBatch = std::stoi(argv[11]); + config.warm_up = std::stoi(argv[12]); + config.repeat = std::stoi(argv[13]); } } else @@ -95,17 +100,101 @@ bool parse_cmd_args(int argc, << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4: verbosity 
(0=no info, 1=verbose info)" << std::endl - << "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl - << "arg11: KBatch" << std::endl; + << "arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC" << std::endl + << "arg11: KBatch" << std::endl + << "arg12: warmup runs pre-timing" << std::endl + << "arg13: repeat run count for timing" << std::endl; + return false; } return true; } -template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, + // 2-k))); + + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / 
NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K_pk; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + +template + ck::index_t ScaleBlockSize> bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) { - static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; - static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; - - static constexpr ck::index_t ScaleBlockSize = MXVectorSize; - - static constexpr ck::index_t KPerBlock = 64; - using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - ADataType, // ADataType - XDataType, // AScaleDataType - BDataType, // BDataType - XDataType, // BScaleDataType - CDataType, // CDataType - AccDataType, // GemmAccDataType - CShuffleDataType, // CShuffleDataType - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - GemmSpec, // GemmSpec - MXVectorSize, // ScaleBlockSize: Scaling block size - 256, // BlockSize: Thread block size - 128, // MPerBlock - 128, // NPerBlock - KPerBlock, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 2, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // 
ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_AK1 - false, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_BK1 - false, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock - BlkGemmPSched, // BlkGemmPipeSched - BlkGemmPVer, // BlkGemmPipelineVer - ADataType, // ComputeTypeA - BDataType // ComputeTypeB - >; + constexpr bool BPreShuffle = ck::is_same_v; + using BRefLayout = ck::conditional_t; auto M = problem_size.M; auto N = problem_size.N; @@ -186,28 +221,19 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto f_host_tensor_descriptor = [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { if constexpr(std::is_same_v) - { return HostTensorDescriptor({row, col}, {stride, 1}); - } else - { return HostTensorDescriptor({row, col}, {1, stride}); - } }; - auto f_get_default_stride = [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { if(stride == -1) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) - { return static_cast(col); - } else - { return static_cast(row); - } } else return static_cast(stride); @@ -222,21 +248,40 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); }; + if(K % ck::packed_size_v != 0 || K % ck::packed_size_v != 0) + { + throw std::runtime_error("wrong! 
K must be multiple of packed size."); + }; + // Hardcode scale layouts as per pipeline assumptions // TODO: Allow user to specify scale layouts using AScaleLayout = Row; using BScaleLayout = Col; - auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Padded_M = ck::math::integer_least_multiple(M, ScaleBlockSize); + auto Scale_Stride_AM = + f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{}); auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, AScaleLayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BScaleLayout{})); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + auto b_k_n = + std::make_shared>(f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); + auto b_input = b_k_n; + if constexpr(BPreShuffle) + b_input = std::make_shared>( + f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); // use layout only for size + // scales for A and B Tensor a_m_k_scale(f_host_tensor_descriptor( - M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A - Tensor b_k_n_scale(f_host_tensor_descriptor( - K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B + Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + Tensor b_k_n_scale( + f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); + + // shuffled scales for A and B + Tensor a_shuffled_scale(f_host_tensor_descriptor( + Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); + Tensor b_shuffled_scale( + f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); Tensor c_m_n_host_result( f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification @@ -247,54 +292,70 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c { std::cout << "a_m_k: " << a_m_k.mDesc << 
std::endl; std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n->mDesc << std::endl; std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl; } + auto a_data_element = [](float x) { + if constexpr(ck::is_same_v) + return ck::type_convert(ck::float2_t(x)); + else if constexpr(ck::packed_size_v == 32) + return ck::type_convert(ck::float32_t(x)); + else if constexpr(ck::packed_size_v == 16) + return ck::type_convert(ck::float16_t(x)); + else + return ck::type_convert(x); + }; + auto b_data_element = [](float x) { + if constexpr(ck::is_same_v) + return ck::type_convert(ck::float2_t(x)); + else if constexpr(ck::packed_size_v == 32) + return ck::type_convert(ck::float32_t(x)); + else if constexpr(ck::packed_size_v == 16) + return ck::type_convert(ck::float16_t(x)); + else + return ck::type_convert(x); + }; + + using int_distr = std::uniform_int_distribution; + using float_distr = std::uniform_real_distribution; switch(config.init_method) { case 0: // Initializations for development and debugging - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); + + ck::utils::FillConstant{a_data_element(0.5f)}(a_m_k); ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); - ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n); - ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); + + ck::utils::FillConstant{b_data_element(2.0f)}(*b_k_n); + ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n_scale); + if(config.verbosity > 0) { - std::cout << "Init A = {1}" << std::endl; + std::cout << "Init A = {0.5}" << std::endl; std::cout << "Init A scale = {2.0}" << std::endl; - std::cout << "Init B = {0.5}" << std::endl; - std::cout << "Init B scale = {1.0}" << std::endl; + std::cout << "Init B = {2.0}" << std::endl; + std::cout << "Init B scale = {0.5}" << std::endl; std::cout << "Expect 
C = {K}" << std::endl; } break; case 1: - - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] - - if constexpr(ck::is_same_v) - { - a_m_k_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} - } - else - { - ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(a_m_k_scale); - ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(b_k_n_scale); - } - + a_m_k.GenerateTensorDistr( + int_distr{-5, 5}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-5,5] + b_k_n->GenerateTensorDistr(int_distr{-5, 5}); // Z[-5,5] + static_assert(ck::is_same_v); + a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} break; case 2: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + a_m_k.GenerateTensorDistr( + float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2] + a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0}); + b_k_n_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); break; default: @@ -304,20 +365,33 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c } } + preShuffleScaleBuffer>(a_m_k_scale.mData.data(), + a_shuffled_scale.mData.data(), + Scale_Padded_M, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); + if constexpr(BPreShuffle) + { + int NPerXdl = 16; // Fixed 16 + 
preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl); + } + if(config.verbosity > 0) std::cout << "Device memory allocation..." << std::endl; - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); - DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n->GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize()); if(config.verbosity > 0) std::cout << "Upload data to device..." << std::endl; a_device_buf.ToDevice(a_m_k.mData.data()); - a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); + a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data()); + b_device_buf.ToDevice(b_input->mData.data()); + b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data()); + if(config.verbosity > 0) std::cout << "Done." 
<< std::endl; @@ -330,9 +404,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), static_cast(c_device_buf.GetDeviceBuffer()), M, N, @@ -354,13 +428,26 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c "not consistent with the supported device_gemm arguments."); } + std::size_t total_size = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n->GetElementSpaceSizeInBytes() + + a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() + + a_shuffled_scale.GetElementSpaceSizeInBytes() + + b_shuffled_scale.GetElementSpaceSizeInBytes(); + const auto total_cnt = ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size); + const int rotating_count = std::max(1, std::min(config.repeat, static_cast(total_cnt))); if(config.verbosity > 0) { std::cout << "Computing GEMM on device..." 
<< std::endl << std::endl; } - float ave_time = - invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); + float ave_time = invoker.Run(argument, + StreamConfig{nullptr, + config.time_kernel, + config.verbosity, + config.warm_up, + config.repeat, + rotating_count > 1, + rotating_count}); bool res_verified = true; if(config.do_verification > 0) @@ -387,7 +474,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto ref_argument = ref_gemm.MakeArgument(a_m_k, a_m_k_scale, - b_k_n, + *b_k_n, b_k_n_scale, c_m_n_host_result, PassThrough{}, @@ -402,20 +489,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c std::cout << "Comparing results..." << std::endl; } - if(config.init_method == 0) - { - auto expected = static_cast(K); - auto computed = type_convert(c_m_n_device_result(1, 12)); - - res_verified = res_verified && std::abs(expected - computed) <= 0.0f; - std::cout << "\nExpected vs Computed: " << expected << " vs " << computed - << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl - << std::endl; - } - - res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, - c_m_n_host_result, - "Error: Incorrect results!"); + res_verified = + res_verified && + ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-1, 5e-1); if(config.verbosity > 0 && res_verified) std::cout << "Verification Successful!" 
<< std::endl; @@ -428,15 +505,18 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c if(config.time_kernel) { - std::size_t flop = std::size_t(2) * M * N * K + - std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(CDataType) * M * N + - sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; + // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of + // partial sums(K/ScaleBlockSize)] + // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; + std::size_t num_btype = + sizeof(ADataType) * M * K / ck::packed_size_v + + sizeof(BDataType) * K * N / ck::packed_size_v + sizeof(CDataType) * M * N + + sizeof(XDataType) * M * K / ScaleBlockSize + sizeof(XDataType) * N * K / ScaleBlockSize; float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; + float gb_per_sec = static_cast(num_btype) / 1e6f / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << device_op.GetTypeString() << std::endl; @@ -445,9 +525,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c return res_verified; } -template , // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 
true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp new file mode 100644 index 0000000000..6e1efd266b --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f4x2_pk_t; +using BDataType = ck::f4x2_pk_t; + +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = MFMA; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +// AB DataType: f4x2_pk_t +// Mathmatically, all numbers are represented as f4x2. 
+using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XPackedDataType, // AScaleDataType + BDataType, // BDataType + XPackedDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 512, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 8, // NXdlPerWave + S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 4, // CShuffleNXdlPerWavePerShuffle + S<1, 8, 1, 32>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlockW + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 
0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp6.cpp b/example/67_gemm_microscaling/gemm_mx_fp6.cpp new file mode 100644 index 0000000000..615980082d --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp6.cpp @@ -0,0 +1,99 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f6x16_pk_t; +using BDataType = ck::f6x16_pk_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / ck::packed_size_v; // K dimension size per block + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Number of threads per block + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 1, // AK1 number of elements to read at a 
time when transferring from global memory to LDS + 1, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp new file mode 100644 index 0000000000..e6fe791178 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256; + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<16, 16, 1>, // 
BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp new file mode 100644 index 0000000000..fdc4ace471 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::bf8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + 256, // KPerBlock + 16, // AK1 + 8, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<32, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // 
BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 2, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp deleted file mode 100644 index 393f4a2ea7..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::e8m0_bexp_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 
0 - : -1; -} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp deleted file mode 100644 index dd654a8f69..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::half_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 0 - : -1; -} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp deleted file mode 100644 index c42d9783be..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::f8_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 0 - : -1; -} diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp new file mode 100644 index 0000000000..aaf0cb3891 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + 
(void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = 
MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t BlockSize = 256; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, BlockSize, + MPerBlock, NPerBlock, KPerBlock, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = 
std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + + Tensor 
a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_k_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_k_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + e_t_k_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + 
a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{0.5f}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.5f}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem 
sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k(token_id, k); + } + } + } + + // A/B scale shuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + 
auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + // FMA * tokens * N * (Gate+Up) * topk * K + + // FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale) + std::size_t(2) * tokens * N * 2 * topk * K + + std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // 
gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm1; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + a1_t_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_k_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp new file mode 100644 index 0000000000..24ab326391 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp @@ -0,0 +1,545 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; 
+ (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = 
MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t BlockSize = 256; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, BlockSize, + MPerBlock, NPerBlock, KPerBlock, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 4096; + ck::index_t K = 6144; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + 
init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + 
ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_k_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_k_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + e_t_k_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + 
d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{0.5f}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.5f}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem 
expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k(token_id, k); + } + } + } + + // A/B scale shuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto 
argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + // FMA * tokens * N * (Gate+Up) * topk * K + + // FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale) + std::size_t(2) * tokens * N * 2 * topk * K + + std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + 
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm1; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + a1_t_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + e_t_k_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp new file mode 100644 index 0000000000..08ed8e11fb --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp @@ -0,0 +1,574 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const 
+ { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// B preshuffle +void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + I64 tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K_pk; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int 
tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 64, KPerBlock, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int 
argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! 
K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({sorted_tile_num + 1})); + max_token_id.mData[0] = valid_size; + + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + // B preshuffle + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + 
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_k_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_k_n_device_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + + e_t_k_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + 
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{0.1f}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id 
== tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k(token_id, k); + } + } + } + + // A/B scale shuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>(b1_e_n_k.mData.data(), + b_scale_preshuffled.mData.data(), + N * 2 * experts, + K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + preShuffleBuffer(b0_e_n_k.mData.data(), + b0_preshuffled.mData.data(), + N * 2 * experts, + K, + device_op.GetPreShuffleParameters()); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + // FMA * tokens * N * (Gate+Up) * topk * K + + // FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale) + std::size_t(2) * tokens * N * 2 * topk * K + + std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K + + sizeof(B0DataType) / 2 * K * N * 2 * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts + + sizeof(EDataType) * tokens * topk * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm1; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k, + a1_t_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < N; ++n) + { + 
e_t_k_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_k_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp new file mode 100644 index 0000000000..1b8a7a16e3 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp @@ -0,0 +1,542 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using 
EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int 
tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, + 16, 16, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape 
+ constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! 
K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + int eids[sorted_tile_num]{}; + for(int i = 0; i < sorted_tile_num; i++) + { + if(i < valid_tile_num) + { + eids[i] = (i * experts) / valid_tile_num; + } + else + { + eids[i] = 3; + } + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + + 
// A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + 
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 8: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * 
max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); + // d2_e_n.savetxt("weight.txt", "int"); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); + } + } + } + + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + 
sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // FMA * tokens * N * topk * K + + // FMA * tokens * N * topk * (K/BlockScale) + std::size_t flop = std::size_t(2) * tokens * topk * N * K + + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor 
c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + a1_t_k_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, // topk weights + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp new file mode 100644 index 0000000000..829bf9af24 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -0,0 +1,526 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bns.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; 
+ (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % (XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static 
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, + 16, 16, + 16, 16, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = sorted_tile_num; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = 
std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + int eids[sorted_tile_num]{}; + for(int i = 0; i < sorted_tile_num; i++) + { + if(i < valid_tile_num) + { + eids[i] = (i * experts) / valid_tile_num; + } + else + { + eids[i] = 3; + } + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 
1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + // B preshuffle + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + 
d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * 
a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); + } + } + } + + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + b0_device_buf.ToDevice(b0_e_n_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + 
b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // FMA * tokens * N * topk * K + + // FMA * tokens * N * topk * (K/BlockScale) + std::size_t flop = std::size_t(2) * tokens * topk * N * K + + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = 
ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + a1_t_k_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, // topk weights + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp new file mode 100644 index 0000000000..efbd0f0c03 --- /dev/null +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -0,0 +1,584 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F4 = ck::f4x2_pk_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using XDataType = ck::e8m0_bexp_t; +using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F4; +using A1DataType = XPackedDataType; +using B0DataType = F4; +using B1DataType = XPackedDataType; +using EDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using D0DataType = F32; +using D1DataType = F32; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +// d0: ascale, d1: bscale, d2:expert weight +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const 
+ { + (void)d0; + (void)d1; + (void)d2; + + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + e = ck::type_convert(c * d0 * d1 * d2); + } +}; + +using CDEElementOp = MulABScaleExpertWeight; + +// B preshuffle +void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + I64 tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K_pk; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} + +// A, B Scale preshuffle +template +void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) +{ + int MNXdlPack = 2; + int KXdlPack = 2; + + int XdlMNThread = 16; + int XdlKThread = 64 / XdlMNThread; + + int K0 = K / KXdlPack / XdlKThread; // KRepeat + + // The 4 16x128 building blocks will be packed into 1 32x256 for F4 + // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 + + // unfold the MN32xK(256/32) scale buffer + // 4 16 2 2 + // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack + // Then, MNRepeat->KRepeat + + for(int n = 0; n < MN; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat + int tempn = n % (XdlMNThread * MNXdlPack); + int n1 = tempn % XdlMNThread; // i XdlMNThread + int n2 = tempn / XdlMNThread; // i MNXdlPack + + int k0 = k / (XdlKThread * KXdlPack); // i KRepeat + int tempk = k % 
(XdlKThread * KXdlPack); + int k1 = tempk % XdlKThread; // i XdlKThread + int k2 = tempk / XdlKThread; // i KXdlPack + + int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + + k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + + k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + + k2 * MNXdlPack + n2; + // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + + // k2 * MNXdlPack))); + if constexpr(KLast) + dst[outputIndex] = src[n * K + k]; + else + dst[outputIndex] = src[k * MN + n]; + } + } +} + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 + +static constexpr ck::index_t MPerBlock = 128; +static constexpr bool MulRoutedWeight = true; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle< + A0Layout, B0Layout, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + ScaleBlockSize, 256, + MPerBlock, 128, KPerBlock, + 16, 16, + 16, 16, + 8, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, + 2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // per expert: + // GEMM shape + 
constexpr ck::index_t sorted_tile_num = 13; + constexpr ck::index_t valid_tile_num = 13; + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + + ck::index_t N = 6144; + ck::index_t K = 4096; + ck::index_t experts = 8; + ck::index_t tokens = 832; + ck::index_t topk = 2; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! 
K must be multiple of ScaleBlockSize."); + }; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize; + ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0, 0, 0}; + + ck::index_t KBatch = 1; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1})); + max_token_id.mData[0] = valid_size; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + int eids[sorted_tile_num]{}; + for(int i = 0; i < sorted_tile_num; i++) + { + if(i < valid_tile_num) + { + eids[i] = (i * experts) / valid_tile_num; + } + else + { + eids[i] = 3; + } + } + + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = eids[i]; + } + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + int token_per_tile = tokens * topk / valid_tile_num; + int tokenid = 0; + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a1_t_k_k( + HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + 
// B preshuffle + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + + // A, B Scale preshuffle + Tensor a_scale_sorted(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor a_scale_preshuffled(HostTensorDescriptor( + {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); + Tensor b_scale_preshuffled( + HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, + {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); + Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); + Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); + + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl; + std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + 
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 5.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 8: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize()); + 
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize()); + + // A scale sorted + for(int i = 0; i < sorted_size; i++) + { + int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF; + int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF; + + for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++) + { + if(token_id == tokens) + { + a_scale_sorted(i, k) = ck::type_convert(0); + } + else + { + a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k); + } + } + } + + // A, B Scale preshuffle + preShuffleScaleBuffer>(a_scale_sorted.mData.data(), + a_scale_preshuffled.mData.data(), + sorted_size, + K / ScaleBlockSize); + preShuffleScaleBuffer>( + b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k_k.mData.data()); + a1_device_buf.ToDevice(a_scale_preshuffled.mData.data()); + b1_device_buf.ToDevice(b_scale_preshuffled.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + 
preShuffleBuffer(b0_e_n_k.mData.data(), + b0_preshuffled.mData.data(), + N * experts, + K, + device_op.GetPreShuffleParameters()); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument( + sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + std::array{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + + if(time_kernel) + { + // not result correct here because output buf not setzero + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + // FMA * tokens * N * topk * K + + // FMA * tokens * N * topk * (K/BlockScale) + std::size_t flop = std::size_t(2) * tokens * topk * N * K + + std::size_t(2) * tokens * topk * N * K / ScaleBlockSize; + + std::size_t num_btype = + sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts + + sizeof(XDataType) * tokens * topk * K / ScaleBlockSize + + sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec 
+ << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // gemm2 use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor c_t_n({tokens, N}); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeMXGemm2; + + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + a1_t_k_k, + b0_e_n_k, + b1_e_n_k, + d2_e_n, // topk weights + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); + + ref_invoker.Run(ref_argument); + for(int t = 0; t < tokens; ++t) + { + for(int n = 0; n < N; ++n) + { + e_t_n_host_result(t, n) = ck::type_convert(c_t_n(t, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + return ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 
0 + : 1; + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 64ff2a6813..7bd628edf2 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -20,34 +20,35 @@ function(add_example_dependencies EXAMPLE_NAME FILE_NAME) endfunction(add_example_dependencies EXAMPLE_NAME) function(add_example_executable EXAMPLE_NAME FILE_NAME) - message("adding example ${EXAMPLE_NAME}") + message(DEBUG "adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) set(test 1) endif() 
if(test EQUAL 1) - message("removing example source file ${source} ") + message(DEBUG "removing example source file ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() @@ -55,91 +56,84 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) - #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") - message("removing dl example ${source} ") + get_filename_component(source_name ${source} NAME) + #Do not build any DL examples if DL_KERNELS not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") + message(DEBUG "removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any DPP examples if DPP_KERNELS not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp") - message("removing dpp example ${source} ") + #Do not build any DPP examples if DPP_KERNELS not set + if(NOT DEFINED DPP_KERNELS AND source_name MATCHES "_dpp") + message(DEBUG "removing dpp example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any XDL examples if gfx9 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") - message("removing xdl example ${source} ") + #Do not build any XDL examples if gfx9 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") + message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") - message("removing wmma example ${source} ") + #Do not build any WMMA examples if gfx11 targets are not on the list + if(NOT 
EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") + message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any microscaling examples if gfx950 target is not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") - message("removing microscaling example ${source} ") + #Do not build any microscaling examples if gfx950 target is not on the list + if(NOT EX_TARGETS MATCHES "gfx950" AND source_name MATCHES "_mx") + message(DEBUG "removing microscaling example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any FP8 examples if CK_ENABLE_FP8 not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") - message("removing fp8 example ${source} ") + #Do not build any FP8 examples if CK_ENABLE_FP8 not set + if(NOT DEFINED CK_ENABLE_FP8 AND source_name MATCHES "_fp8") + message(DEBUG "removing fp8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any BF8 examples if CK_ENABLE_BF8 not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8") - message("removing bf8 example ${source} ") + #Do not build any BF8 examples if CK_ENABLE_BF8 not set + if(NOT DEFINED CK_ENABLE_BF8 AND source_name MATCHES "_bf8") + message(DEBUG "removing bf8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - # Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply_xdl_fp8_bpreshuffle") - message("Skipping ${source} example for current target") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() + # Build fp8 gemm_multiply_multiply and moe 
only on gfx94/95 + if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") + if(source_name MATCHES "fp8" AND source_name MATCHES "(gemm_multiply_multiply|moe)") + message(DEBUG "Skipping ${source} example for current target") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endif() endforeach() #only continue if there are some source files left on the list + set(source_name_list "") + foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) + list(APPEND source_name_list ${source_name}) + endforeach() if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_wmma") + if(source_name_list MATCHES "_xdl" AND NOT source_name_list MATCHES "_pk_i4") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) - elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 - message("trimming targets for ${FILE_NAME}") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + elseif(source_name_list MATCHES "_mx") 
#only build mx example for gfx950 + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 + message(DEBUG "trimming targets for ${FILE_NAME}") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt) add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) - set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS}) add_dependencies(examples ${EXAMPLE_NAME}) add_dependencies(check ${EXAMPLE_NAME}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() - #message("add_example returns ${result}") + message(DEBUG "add_example returns ${result}") if(result EQUAL 0 AND NOT "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) - #message("adding to SMOKE EXAMPLE FILTER ${EXAMPLE_NAME}") set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "SMOKE_TEST") add_dependencies(smoke ${EXAMPLE_NAME}) elseif(result EQUAL 0 AND "${EXAMPLE_NAME}" IN_LIST REGRESSION_EXAMPLES) - #message("Adding to REGRESSION EXAMPLE FILTER ${EXAMPLE_NAME}") set_tests_properties(${EXAMPLE_NAME} PROPERTIES LABELS "REGRESSION_TEST") add_dependencies(regression ${EXAMPLE_NAME}) endif() @@ -153,83 +147,89 @@ function(add_example_dependencies EXAMPLE_NAME FILE_NAME) 
endfunction(add_example_dependencies EXAMPLE_NAME) function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) - message("adding example ${EXAMPLE_NAME}") + message(DEBUG "adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS FILE_NAME) - set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) - set(test 1) - endif() - if(test EQUAL 1) - message("removing example ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() - endforeach() + foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) + set(test 0) + if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + 
endif() + if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message(DEBUG "removing example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() endif() set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) - #Do not build any DL examples if DL_KERNELS not set + set(source_name_list "") foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") - message("removing dl example ${source} ") + get_filename_component(source_name ${source} NAME) + #Do not build any DL examples if DL_KERNELS not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") + message(DEBUG "removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any XDL examples if gfx9 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") - message("removing xdl example ${source} ") + #Do not build any XDL examples if gfx9 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") + message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") - message("removing wmma example ${source} ") + #Do not build any WMMA examples if gfx11 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") + message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() + list(APPEND source_name_list ${source_name}) endforeach() #only continue if there are some source files left on the list if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx900 
gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_wmma") + if(source_name_list MATCHES "_xdl") + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) add_dependencies(examples ${EXAMPLE_NAME}) - set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() - - #message("add_example returns ${result}") + + message(DEBUG "add_example returns ${result}") set(result ${result} PARENT_SCOPE) endfunction(add_example_executable_no_testing EXAMPLE_NAME) +function(example_compile_options EXAMPLE_NAME) + if(TARGET ${EXAMPLE_NAME}) + target_compile_options(${EXAMPLE_NAME} ${ARGN}) + endif() +endfunction(example_compile_options) + # add all example subdir file(GLOB dir_list LIST_DIRECTORIES true *) FOREACH(subdir ${dir_list}) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 9ba3a453fc..bd03aee924 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -1,7 +1,7 @@ # validate user-specified fmha_fwd API list -set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv") +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING - "semicolon-separated list of APIs to generate 
(${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") + "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") if(FMHA_FWD_ENABLE_APIS STREQUAL "all") set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS}) endif() @@ -17,24 +17,45 @@ if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) list(APPEND FMHA_FWD_ENABLE_APIS "fwd") endif() +file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py +) +# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change +set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}") + string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") +set(FMHA_FWD_CODE_GEN_COMMON_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api ${FMHA_FWD_APIS} + --optdim 32,64,128,256 + # --filter fmha_fwd... +) +set(FMHA_BWD_CODE_GEN_COMMON_ARGS + ${CMAKE_CURRENT_LIST_DIR}/generate.py + --api bwd + --receipt 3 + --optdim 32,64,128,256 + # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... 
+) + # generate a list of kernels, but not actually emit files at config sta execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt RESULT_VARIABLE ret ) if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") + message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.") endif() execute_process( - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3 + COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt RESULT_VARIABLE ret ) if(ret AND NOT ret EQUAL 0) - message( FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.") + message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.") endif() # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory @@ -44,20 +65,22 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} ) add_custom_command( OUTPUT ${FMHA_BWD_GEN_BLOBS} - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3 + COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS} + --output_dir ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${CODE_GEN_SCRIPTS} ) 
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -message("adding example ${EXAMPLE_FMHA_FWD}") +message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}") add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) @@ -65,7 +88,7 @@ target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -message("adding example ${EXAMPLE_FMHA_BWD}") +message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}") add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp) target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) @@ -73,7 +96,7 @@ target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS}) # NOTE: this is dangerous since will change the whole kernel to flush denormals # WIP with compiler team for an exp2 intrinsic..., then remove this if(NOT DEFINED FMHA_FWD_FAST_EXP2) - set(FMHA_FWD_FAST_EXP2 true) + set(FMHA_FWD_FAST_EXP2 true) endif() set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) @@ -82,9 +105,9 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # ... 
because they are auto-generated if(FMHA_FWD_FAST_EXP2) - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) else() - list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) endif() list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero) @@ -102,6 +125,13 @@ else() list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0) endif() +# conditionally enable call to the pagedkv_prefill API in fmha_fwd example +if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1) +else() + list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) +endif() + # conditionally specify the use of OCP_FP8 if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 12414a20ed..72109a660b 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -71,6 +71,7 @@ args: -drop_seed seed for random number generator (default:1) -drop_offset offset for random number generator (default:0) -drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0) + -num_splits number of splits for key/value. 
0 to determine actual number by heuristic (default:1) -warmup number of iterations before benchmark the kernel (default:5) -repeat number of iterations to benchmark the kernel (default:20) ``` diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 332707eafd..42a9d5148a 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -114,15 +114,22 @@ LAYOUT_MAP = { PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", + "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload", } PIPELINE_ENUM_MAP = { "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", + "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD", } BOOL_MAP = { "t" : "true", - "f" : "false" + "f" : "false", + True : "true", + False : "false", } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py new file mode 100644 index 0000000000..5d55e8bc36 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -0,0 +1,626 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass, field +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_BATCH_PREFILL_PIPELINE_MAP = { + "qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename 
FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + false, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + +#include + +template<> +float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_batch_prefill_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp" +FMHA_FWD_API=""" +#include + +namespace {{ +bool get_num_cus(unsigned& num_cu) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cu = props.multiProcessorCount; + return true; +}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, 
unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + +float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{ + float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>; + return fmha_batch_prefill_(s, a); + }} +""" + +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return 'true' + else: + return f'{self.bool_expr}' + + def __and__(self, other): + return CppConstraint(f'({str(self)}) && ({str(other)})') + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to 
generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + constraint : CppConstraint + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n 
+= '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant], + F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 
else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + 
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return 
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) + +class KernelComponentFactory: + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + 128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + else: + return None + + @staticmethod + def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? 
+ squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + else: + assert False + return pipelines + +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == 'fp16' or dtype == 'bf16': + if 128 in result.keys(): + result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + return result + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = CustomFactory.get_hdim_tile_size_dict(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)): + if mode == "group": + if pipeline.F_spad != 't' or 
pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_batch_prefill) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_batch_prefill C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + 
return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 6326a97f8e..bb3a0587e7 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
# generate kernel instances to speed up compilation import copy @@ -7,22 +7,14 @@ from dataclasses import dataclass import fnmatch import itertools from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Tuple, Dict, Literal, Any +from collections import defaultdict from codegen.cmake_config import * from codegen.cpp_symbol_map import * +from codegen.utils import update_file -BWD_DQDKDV_PIPELINE_MAP = { - "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP", - "kr_ktr_vr" : "ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR", -} - -BWD_DQDKDV_PIPELINE_ENUM_MAP = { - "kr_ktr_vr_iglp" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP", - "kr_ktr_vr" : "ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR", -} - FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py @@ -39,6 +31,7 @@ using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>; using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>; using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>; using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>; +using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>; // TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape // G0&G2 -> GSdP @@ -54,12 +47,14 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; + fmha_warp_tile2_{F_idx}, + {F_maxq}>; -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, - {F_skpad}, +using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits; -using fmha_bwd_pipeline_{F_idx} = {F_pipeline}; +using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline; using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, - {F_skpad}, + 
false, {F_dpad}>>; using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, - {F_skpad}, + false, {F_dvpad}>>; +using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue< + ck_tile::Default2DEpilogueProblem::AccDataType, + typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType, + false, + {F_dpad}>>; + using fmha_bwd_dq_dk_dv_kernel_{F_idx} = ck_tile::FmhaBwdDQDKDVKernel; + fmha_bwd_dv_epilogue_{F_idx}, + fmha_bwd_dq_epilogue_{F_idx}>; using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, - {F_pipeline_enum}, fmha_mask_{F_idx}, fmha_dropout_{F_idx}, {F_bias}, {F_dbias}, - {F_spad}, - {F_skpad}, {F_dpad}, {F_dvpad}, - {F_deterministic}>; + {F_deterministic}, + {F_trload}, + {F_maxq}>; #include @@ -152,6 +154,13 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_co ck_tile::stream_config{{s.stream_id_}}); }} +template <> +int fmha_bwd_dq_dk_dv_maxq_() +{{ + using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx}; + return k_::kMaxSeqLenQ; +}} + template <> std::string fmha_bwd_dq_dk_dv_get_name_() {{ @@ -167,135 +176,59 @@ FMHA_BWD_API=""" template float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) {{ - if(s.log_level_ > 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} - ); + if constexpr (!std::is_same_v) + {{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_convert_dq_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const 
ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} + ); + }} + else + {{ + if(s.log_level_ > 0) + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << "@" << fmha_bwd_dq_dk_dv_get_name_() << std::flush; + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }} + ); + }} }} template <> float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ + const bool has_load_tr = ck_tile::is_load_tr_supported(); float r = -1; {F_dispatch} return r; }} """ -FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ -{F_hdim_case} - }} -""" -FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ -{F_inner_dispatch} - }} +def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str: + lines = [ + f"{'if' if if_ == 0 else 'else if'}({F_cond})", + "{", + *[' ' + line for line in F_body.split('\n') if line.strip() != ''], + "}", + ] + return '\n'.join(' ' * indent + line for line in lines) + '\n' + + +FMHA_BWD_API_INNER_DISPATCH=""" +{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && + ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; + using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, 
{F_dpad}, {F_deterministic}>; + r = fmha_bwd_>(s, a); + return r; +}} """ -FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_deterministic}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dpad}, {F_deterministic}>; - r = fmha_bwd_(s, a); - return r; - }} -""" - -@dataclass -class FmhaBwdDQDKDVApiTrait: - pipeline : str - # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along k seqlen - bhdq : int # q head_dim - bhdv : int # v head_dim - mask : str - bias : str - dbias : str - dropout : str - spad : str - skpad : str - dpad : str - dvpad : str - deterministic : str - - def scheck(self, spad1 : str) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.spad == 't' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} != 0' - elif self.spad == 'f' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' - else: # self.skpad == 'f' and skpad1 == 'f' - return f'a.seqlen_q % 64 == 0' - - @property - def skcheck(self) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.skpad == 't': - return f'a.seqlen_k % {self.bn0} != 0' - else: - return f'a.seqlen_k % {self.bn0} == 0' - - @property - def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q 
% {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' - - @property - def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' - -class FmhaBwdApiPool: - def __init__(self, mask_impl): - self.dq_dk_dv_pool = dict() - self.mask_impl = mask_impl - - def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.dq_dk_dv_pool.keys(): - self.dq_dk_dv_pool[trait.dtype] = dict() - if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): - self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() - - self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): - traits=self.dq_dk_dv_pool[dtype][hdim] - hdim_int = int(hdim) - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - for spad1 in ["t", "f"]: - if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): - continue - inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], - F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - 
F_deterministic=BOOL_MAP[trait.deterministic]) - - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: - # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) +# M0 size for 1d kernels (dot/convert) +M0_1D = 64 # GEMM0: Q@K=S^T # GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) @@ -303,7 +236,7 @@ class FmhaBwdApiPool: # GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) # GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) # Is it necessary to distinguish between K0~K4? -@dataclass +@dataclass(frozen=True) class FmhaBwdDQDKDVTileSize: F_bm0 : int # tile size along q seqlen (block size) F_bn0 : int # tile size along k seqlen @@ -330,20 +263,20 @@ class FmhaBwdDQDKDVTileSize: F_wn1 : int # warp size along n in gemm1/gemm3 F_wk1 : int # warp size along k in gemm1/gemm3 F_occupancy : int # occupancy + max_seq_q : int = 0 + @property def name(self) -> str: return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\ f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ - f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}" + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}" -@dataclass +@dataclass(frozen=True) class FmhaBwdDQDKDVKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim F_dtype : str # 
data type F_tile : FmhaBwdDQDKDVTileSize - F_spad : str # true/false - F_skpad : str # F_dpad : str # F_dvpad : str # F_bias : str # @@ -352,8 +285,8 @@ class FmhaBwdDQDKDVKernel: F_mask : str # value from MASK_MAP F_mode : str # value from MODE_MAP F_deterministic : str # - F_pipeline : str # mask_impl : str # + F_trload : str # @property def template(self) -> str: @@ -386,8 +319,6 @@ class FmhaBwdDQDKDVKernel: F_wm1 = self.F_tile.F_wm1, F_wn1 = self.F_tile.F_wn1, F_wk1 = self.F_tile.F_wk1, - F_spad = BOOL_MAP[self.F_spad], - F_skpad = BOOL_MAP[self.F_skpad], F_dpad = BOOL_MAP[self.F_dpad], F_dvpad = BOOL_MAP[self.F_dvpad], F_bias = BIAS_MAP[self.F_bias], @@ -397,21 +328,20 @@ class FmhaBwdDQDKDVKernel: F_mask = get_mask_map(self.mask_impl)[self.F_mask], F_mode = MODE_MAP[self.F_mode], F_deterministic = BOOL_MAP[self.F_deterministic], - F_pipeline_enum = BWD_DQDKDV_PIPELINE_ENUM_MAP[self.F_pipeline], - F_pipeline = BWD_DQDKDV_PIPELINE_MAP[self.F_pipeline]) + F_trload = BOOL_MAP[self.F_trload], + F_maxq = self.F_tile.max_seq_q + ) @property def name(self) -> str: def pad_name() -> str: n = '' - if self.F_spad == 't': n += 's' - if self.F_skpad == 't' : n += 'sk' if self.F_dpad == 't' : n += 'd' if self.F_dvpad == 't' : n += 'dv' if n != '' : n = 'p' + n return n pn = pad_name() - n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name + f'_{self.F_pipeline}' + n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name if pn != '' : n += f'_{pn}' else: n += '_npad' @@ -433,122 +363,34 @@ class FmhaBwdDQDKDVKernel: if self.F_deterministic == 't' : n += '_deterministic' else: n += '_ndeterministic' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' return n @property def filename(self) -> str: return self.name + ".cpp" - def api_trait(self) -> FmhaBwdDQDKDVApiTrait: - return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - 
bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bhdq=self.F_tile.F_bhdq, - bhdv=self.F_tile.F_bhdv, - mask=self.F_mask, - bias=self.F_bias, - dbias=self.F_dbias, - dropout=self.F_dropout, - spad=self.F_spad, - skpad=self.F_skpad, - dpad=self.F_dpad, - dvpad=self.F_dvpad, - deterministic=self.F_deterministic - ) - # TODO: design a more practical way to do it -# this is current supported tile size & pipeline. -def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], - '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], - '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"], - '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - "kr_ktr_vr_iglp", "kr_ktr_vr"] - } +# this is current supported tile size. 
+def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: + if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': + return [ + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + ] + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': + return [ + FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), + # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), + FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), + ] else: - return None - -def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: - # TODO: we don't support tuning yet, so pick up one value for pad - # support this in future - gen = list() - api_pool = FmhaBwdApiPool(mask_impl) - - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): - tile = d[hdim_str][0] - ppl = d[hdim_str][1] - hdim = int(hdim_str) - if (mode == 
"group") and (spad == "f" or skpad == "f"): - continue - if ((bias == "no" or bias == "alibi") and dbias == "t"): - continue - if ("wg32" in dropout): - continue - if (dpad == "t" or dvpad == "t"): - ppl = d[hdim_str][2] - k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, - F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, - F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, - F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Flash attention integration - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - elif receipt == 3: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dpad == dvpad - cond &= deterministic == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'bias'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - cond &= deterministic == "f" - if not cond: - continue - # Aiter (mha_bwd) integration - elif receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - # Aiter (mha_varlen_bwd) integration - elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - api_pool.register_dq_dk_dv_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) + return [] FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; @@ -560,7 +402,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot typename 
FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, - /* BlockSize = */ 64, + /* BlockSize = M0 = */ 64, {F_hdim}, {F_mode}, fmha_bwd_dot_do_o_trait_{F_idx}>; @@ -608,7 +450,7 @@ std::string fmha_bwd_dot_do_o_get_name_() }} """ -@dataclass +@dataclass(frozen=True) class FmhaBwdOGradDotOKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -648,44 +490,6 @@ class FmhaBwdOGradDotOKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 - - gen = list() - - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) - if (mode == "group" and spad == "f"): - continue - k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, - F_spad=spad, F_dvpad=dvpad, F_mode=mode, - F_occupancy=get_occupancy(dtype, hdim)) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Aiter (mha_bwd) integration - if receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - if not cond: - continue - # Aiter (mha_varlen_bwd) integration - elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue - gen.append(k) - - return gen - FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; @@ -752,7 +556,7 @@ std::string fmha_bwd_convert_dq_get_name_() }} """ -@dataclass +@dataclass(frozen=True) class FmhaBwdConvertQGradKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim 
@@ -764,6 +568,7 @@ class FmhaBwdConvertQGradKernel: F_mode : str # value from MODE_MAP F_occupancy : int # F_deterministic : str # + disabled : bool # sometimes this kernel is not used @property def template(self) -> str: @@ -800,83 +605,275 @@ class FmhaBwdConvertQGradKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 +@dataclass(frozen=True) +class FmhaBwdApiTrait: + idx : int # this is not a tunable, but a counter to differentiate symbol + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : int + dtype : str # data type + mode : str # value from MODE_MAP + tile : FmhaBwdDQDKDVTileSize + mask : str + bias : str + dbias : str + dropout : str + spad1d : str # spad for 1d kernels (dot/convert) + dpad : str + dvpad : str + deterministic : str + mask_impl : str + tr_load : str - gen = list() + @property + def bm0(self) -> int: + return self.tile.F_bm0 + @property + def bn0(self) -> int: + return self.tile.F_bn0 + @property + def bhdq(self) -> int: + return self.tile.F_bhdq + @property + def bhdv(self) -> int: + return self.tile.F_bhdv - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) - tile = d[hdim_str][0] - if (mode == "group" and spad == "f"): + @property + def scheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad1d == 't': + return f'a.seqlen_q % {M0_1D} != 0' + else: # self.spad1d == 'f' + return f'a.seqlen_q % {M0_1D} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % 
{self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + + @property + def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, + F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + + @property + def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: + return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, + F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout, + F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load) + + @property + def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, + F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad, + F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), + F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))) + + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: + # TODO: do we need to check duplication? 
+ self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait)) + + @staticmethod + def if_(i: int) -> str: + return 'if' if i == 0 else 'else if' + + def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str: + inners = "" + i = 0 + for trait in traits: + inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], + F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], + F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, + F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled]) + i += 1 + return inners + + @staticmethod + def trload_sort_key(tf): + return 0 if tf == 't' else 1 # sort 't' before 'f' + + @staticmethod + def max_seq_q_sort_key(max_seq_q): + return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end + + @staticmethod + def max_seq_q_cond(max_seq_q: int) -> str: + if max_seq_q == 0: + return 'true /* no seqlen_q limit */' + else: + return f'a.seqlen_q <= {max_seq_q}' + + @staticmethod + def dtype_cond(dtype: str) -> str: + return f't.data_type.compare("{dtype}") == 0' + + @staticmethod + def hdim_cond(hdim: int) -> str: + return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}' + + @property + def api(self) -> str: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true /* no trload requirement */" + } + per_tr_load = '' + for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key): + per_max_seq_q = '' + for max_seq_q in 
sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key): + per_dtypes = '' + for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]): + per_hdim_case = '' + for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]): + traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim] + inners = self._api_innders(traits) + per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners) + per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case) + per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes) + per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4) + if not per_tr_load: + # empty string we add some ignore to suppress warning in api + per_tr_load += ' (void)t ; (void)s ; (void)a;' + result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load) + return result.replace('\n\n', '\n') + +def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: + if filter_list == '': + filter_list = '*@*@*' + filters = filter_list.split('@') + filters.extend(['*'] * (3 - len(filters))) + filter_dot_do_o = filters[0] + filter_convert_dq = filters[1] + filter_dq_dk_dv = filters[2] + + # use dict as ordered set + gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {} + gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {} + gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {} + api_pool = FmhaBwdApiPool(mask_impl) + + for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): + tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) + for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), 
BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): + assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" + hdim = tile.F_bhdq + if (mode == "group") and (spad1d == "f"): continue - k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0, - F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): + if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0: + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + if ("wg32" in dropout): + continue + if tr_load == "t" and (dpad == "t" or dvpad == "t"): + continue # tr_load cannot work with dpad or dvpad + t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load) + + if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): + continue + if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): + continue + if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + + # Flash attention integration + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue + elif receipt == 3: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: 
continue # Aiter (mha_bwd) integration - if receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - if not cond: - continue + elif receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue # Aiter (mha_varlen_bwd) integration elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue - gen.append(k) + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue + gen_dot_do_o[t.dot_do_o_kernel] = True + gen_dq_dk_dv[t.dq_dk_dv_kernel] = True + if not t.convert_dq_kernel.disabled: + gen_convert_dq[t.convert_dq_kernel] = True + api_pool.register_dq_dk_dv_traits(t) - return gen + return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys()) -def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list) + update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api) + for k in kernels_dot_do_o: + update_file(output_dir / k.filename, k.template) + for k in kernels_convert_dq: + update_file(output_dir / k.filename, k.template) + for k in kernels_dq_dk_dv: + update_file(output_dir / k.filename, k.template) -def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) -def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, 
autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) - -def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (3 - len(filter_list))) - - kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) - for kernel in kernels: - write_single_bwd_dot_do_o_kernel(kernel, output_dir) - kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) - for kernel in kernels: - write_single_bwd_convert_dq_kernel(kernel, output_dir) - api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) - for kernel in kernels: - write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) - write_bwd_api(api_pool, output_dir) - -def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (3 - len(filter_list))) - - with file_path.open('a') as f: - kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") +def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None: + _, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( + filter_list, receipt, mask_impl, optdim_list + ) + with file_path.open("a") as f: + for k in kernels_dot_do_o: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + for k in kernels_dq_dk_dv: + f.write(str(file_path.parent / GEN_DIR / 
k.filename) + "\n") + for k in kernels_convert_dq: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e5d11c6dc9..f614f42e6b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -3,14 +3,16 @@ # generate kernel instances to speed up compilation import copy -from dataclasses import dataclass +from dataclasses import dataclass, field import fnmatch import itertools +import os from pathlib import Path from typing import List, Optional, Tuple from codegen.cmake_config import * from codegen.cpp_symbol_map import * +from codegen.utils import update_file DTYPE_BITS = { @@ -26,12 +28,14 @@ K0_MAX_SUBMAX_MAP = { 64 : 64, 96 : 128, 128: 128, + 192: 192, 256: 256 } FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n // auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd.hpp" """ @@ -51,12 +55,17 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, false, {F_lse}, {F_dropout}, {F_squant}, - {F_occupancy}>; + {F_occupancy}, + {F_skip}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + using fmha_mask_{F_idx} = {F_mask}; using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< @@ -73,7 +82,9 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::ODataType, fmha_shape_{F_idx}, {F_mode}, + fmha_variant_{F_idx}, fmha_mask_{F_idx}, + {F_trload}, fmha_trait_{F_idx}>; using fmha_pipeline_{F_idx} = {F_pipeline}< @@ -88,7 +99,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; #include @@ -107,13 +118,64 @@ float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" FMHA_FWD_API=""" +#include + +#include + +namespace {{ +bool get_num_cus(unsigned& num_cus) {{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device"); + return false; + }} + + hipDeviceProp_t props{{}}; + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) {{ + fprintf(stderr, "failed to get device properties"); + return false; + }} + + num_cus = props.multiProcessorCount; + return true; 
+}} + +unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{ + const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; + const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1 + + return batch * nheads * num_m_blocks * num_n_blocks; +}} +}} // namespace + float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ float r = -1; + + [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate + + unsigned num_cus; + if (!get_num_cus(num_cus)) {{ + return r; + }} + + [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ + return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); + }}; + + const bool has_load_tr = ck_tile::is_load_tr_supported(); + {F_dispatch} return r; }} """ +FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{ +{F_dtype_case} + }} +""" + FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} @@ -123,52 +185,75 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && - ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && 
(t.skip_min_seqlen_q == {F_skip}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>; return fmha_fwd_(s, a); }} """ +@dataclass +class CppConstraint: + bool_expr: str = None + + def __str__(self): + if self.bool_expr is None: + return 'true' + else: + return f'{self.bool_expr}' + + def __and__(self, other): + return CppConstraint(f'({str(self)}) && ({str(other)})') + @dataclass class FmhaFwdApiTrait: pipeline_tag : str # sync with fmha_fwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along qk seqlen - bk0 : int # tile size along qk gemm unroll - bn1 : int # tile size along v head_dim - bk1 : int # tile size along kv gemm unroll - bk0max : int - vlayout : str - mask : str - bias : str # - lse : str # - dropout : str - squant : str # - spad : str - skpad : str - dpad : str - dvpad : str + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + skip : str + tr_load : str + constraint : CppConstraint @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - 
f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}' @property def scheck(self) -> str: if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true - if self.pipeline_tag == 'qr_async': + if self.pipeline_tag in ['qr_async', 'qr_async_trload']: if self.spad == 't' : return 'true' # always support else : return 'true' - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False + + @property + def seqtune(self) -> str: + if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true + else: + return f'a.seqlen_q <= {self.bm0}' @property def skcheck(self) -> str: @@ -176,9 +261,12 @@ class FmhaFwdApiTrait: if self.pipeline_tag == 'qr_async': if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) else : return f'a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag == 'qr_async_trload': + if self.skpad == 't' : return 'true' + else: return 'true' else: assert False @property @@ -187,7 +275,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -199,7 +287,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -209,16 +297,20 @@ class FmhaFwdApiTrait: class FmhaFwdPipeline: tag : str - F_vlayout : str # row/col - F_spad : str # true/false - F_skpad : str # - F_dpad : str # - F_dvpad : str # - F_bias : str # true/false - F_lse : str # - F_dropout : str # - F_squant : str # - F_mask : str # value from MASK_MAP + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + F_skip : str # true/false + F_trload : str # true/false + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) @property def name(self) -> str: @@ -235,6 +327,9 @@ class FmhaFwdPipeline: if pn != '' : n += f'_{pn}' else: n += '_npad' + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' else: n += '_nbias' @@ -251,8 +346,15 @@ class FmhaFwdPipeline: if self.F_dropout == 't' : n += '_dropout' else: n += '_ndropout' + if self.F_skip == 't' : n += '_skip' + else: n += '_nskip' + if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' + + if self.F_trload == 't' : n += '_trload' + else: n += '_ntrload' + return n class FmhaFwdApiPool: @@ -264,59 +366,71 @@ class FmhaFwdApiPool: # TODO: do we need to check duplication? 
if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + self.pool[trait.dtype][hdim].append(copy.copy(trait)) @property def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], - F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, - F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, - F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: + tr_load_cond_map = { + "t": "has_load_tr", + "f": "true" + } + + per_tr_load =str() + for tr_load in 
["t", "f"]: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][(hdim, hdim_v)] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_constraint=trait.constraint, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes) + if not per_tr_load: # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + per_tr_load += ' (void)t ; (void)s ; (void)a;' 
+ return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load) @dataclass class FmhaFwdTileSize: - F_bm0 : int # tile size along q seqlen (block size) - F_bn0 : int # tile size along k seqlen - F_bk0 : int # tile size along qk gemm unroll - F_bn1 : int # tile size along v head_dim - F_bk1 : int # tile size along kv gemm unroll - F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) - F_rm0 : int # number of warps for gemm0 along q seqlen - F_rn0 : int # number of warps for gemm0 along k seqlen - F_rk0 : int # number of warps for gemm0 along head dim q (not used) - F_rm1 : int # number of warps for gemm1 along q seqlen - F_rn1 : int # number of warps for gemm1 along head dim v - F_rk1 : int # number of warps for gemm1 along k seqlen (not used) - F_wm0 : int # gemm0 warp size along m - F_wn0 : int # gemm0 warp size along n - F_wk0 : int # gemm0 warp size along k - F_wm1 : int # gemm1 warp size along m - F_wn1 : int # gemm1 warp size along n - F_wk1 : int # gemm1 warp size along k - F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : 
int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint()) + @property def name(self) -> str: return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ @@ -365,15 +479,18 @@ class FmhaFwdKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_skip = BOOL_MAP[self.F_pipeline.F_skip], F_occupancy = self.F_tile.F_occupancy, F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], F_mode = MODE_MAP[self.F_mode], - F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag], + F_trload = BOOL_MAP[self.F_pipeline.F_trload]) @property def name(self) -> str: @@ -399,6 +516,7 @@ class FmhaFwdKernel: bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, dropout=self.F_pipeline.F_dropout, @@ -406,33 +524,45 @@ class FmhaFwdKernel: spad=self.F_pipeline.F_spad, skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, - dvpad=self.F_pipeline.F_dvpad) + dvpad=self.F_pipeline.F_dvpad, + skip=self.F_pipeline.F_skip, + tr_load=self.F_pipeline.F_trload, + constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint) -# TODO: design a more practical way to do it -# this is current supported tile size per hdim -def 
get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': - return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - } - elif dtype == 'fp8' or dtype == 'bf8': - return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - } - else: - return None +class KernelComponentFactory: + # TODO: design a more practical way to do it + # this is current supported tile size per hdim + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 
1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + # (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } + else: + return None -def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + @staticmethod + def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! 
@@ -440,33 +570,29 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - if hdim == 256: - # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + if hdim == 256 and hdim_v == 256: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + # the below two is used for hdim vectorize load + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) 
+ pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) + pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + for logits, mask, bias in itertools.product(["t", "f"], 
get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -474,26 +600,45 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm assert False return pipelines +class CustomFactory(KernelComponentFactory): + @staticmethod + def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: + result = KernelComponentFactory.get_hdim_tile_size_dict(dtype) + if dtype == 'fp16' or dtype == 'bf16': + if (128, 128) in result.keys(): + result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate'))) + return result + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: gen = list() api_pool = FmhaFwdApiPool(mask_impl) + factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory + for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_tile_dict_from_dtype(dtype) + d = factory.get_hdim_tile_size_dict(dtype) if d == None: continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): + for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if hdim == 192 and tile.F_bn1 == 128: + if 
(hdim, hdim_v) == (192, 128): # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': + if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue k = FmhaFwdKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -504,12 +649,16 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # 2 - Flash attention integration if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_skip == 'f' if not cond: continue # PyTorch integration @@ -518,6 +667,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_bias in ['no', 'bias'] cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' if not cond: continue # Aiter(mha_fwd) integration @@ -536,26 +688,34 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_squant == 'f' if not cond: continue + # aiter::mha_fwd C++ 
api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) return (api_pool, gen) def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) + update_file(autogen_dir / kernel.filename, kernel.template) def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : str, receipt, mask_impl) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: with file_path.open('a') as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index f243020dc4..2e5bc2bd3d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -273,7 +273,7 @@ def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> 
Optional[dict]: else: return None -def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: +def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: @@ -326,12 +326,21 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # 2 - Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -343,15 +352,15 @@ def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> Non def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: - api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : 
Optional[str], receipt, optdim_list, mask_impl) -> None: with file_path.open('a') as f: - _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index c6d1a01792..b2d962cd74 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -34,17 +34,18 @@ K0_MAX_SUBMAX_MAP = { 64 : 64, 96 : 128, 128: 128, + # 160: 160, 256: 256 } FMHA_FWD_SPLITKV_PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS", "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS", - "qr_async" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVSAsync", } FMHA_FWD_SPLITKV_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; namespace {{ @@ -63,6 +64,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, /*kHasBiasGrad=*/false, {F_lse}, @@ -85,16 +87,19 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::OaccDataType, fmha_shape, {F_mode}, + fmha_variant_{F_idx}, fmha_mask_{F_idx}, fmha_trait>; using fmha_pipeline = {F_pipeline}< fmha_pipeline_problem>; +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue using fmha_epilogue = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, - {F_spad}, {F_dvpad}>>; + false, 
false>>; using fmha_kernel = ck_tile::FmhaFwdSplitKVKernel; @@ -111,7 +116,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }} using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -265,9 +270,9 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -308,6 +313,7 @@ class FmhaFwdSplitKVApiTrait: bk0max : int vlayout : str mask : str + logits : str 
bias : str # lse : str # squant : str # @@ -320,7 +326,7 @@ class FmhaFwdSplitKVApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ f'{self.dvpad}-{self.pagedkv}' @property @@ -378,6 +384,7 @@ class FmhaFwdSplitKVPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # + F_logits : str # t/f F_bias : str # true/false F_lse : str # F_squant : str # @@ -399,6 +406,9 @@ class FmhaFwdSplitKVPipeline: if pn != '' : n += f'_{pn}' else: n += '_npad' + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' else: n += '_nbias' @@ -440,10 +450,10 @@ class FmhaFwdSplitKVCombinePipeline: n = f'{self.tag}' if pn != '' : n += f'_{pn}' else: n += '_npad' - + if self.F_lse == 't' : n += '_lse' else: n += '_nlse' - + if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' return n @@ -473,7 +483,7 @@ class FmhaFwdSplitKVApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, 
F_dvcheck=trait.dvcheck, @@ -539,6 +549,7 @@ class FmhaFwdSplitKVKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], @@ -572,6 +583,7 @@ class FmhaFwdSplitKVKernel: bk1=self.F_tile.F_bk1, bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, + logits=self.F_pipeline.F_logits, mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, @@ -624,8 +636,9 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + # '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -642,8 +655,9 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d return { '32' : FmhaFwdSplitKVCombineTileSize(32, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), + '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + # '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -655,7 +669,7 @@ def 
get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d else: return None -def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: Pipeline = FmhaFwdSplitKVPipeline Kernel = FmhaFwdSplitKVKernel @@ -669,26 +683,21 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): - # TODO: use async pipeline when compiler is more stable - if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: - # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - else: - pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', 
squant, pagedkv, mask)) - if receipt == 1: - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim - pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) elif dtype in ['fp8', 'bf8']: - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask)) + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask)) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -712,6 +721,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue k = Kernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -722,6 +734,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not 
fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] @@ -730,6 +745,15 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + cond &= mode == 'batch' + if not cond: + continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: cond = dtype in ['fp16', 'bf16'] @@ -738,12 +762,19 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + # aiter::mha_fwd_splitkv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) return (api_pool, gen) -def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]: Pipeline = FmhaFwdSplitKVCombinePipeline Kernel = FmhaFwdSplitKVCombineKernel @@ -790,12 +821,20 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Aiter(mha_varlen_fwd) integration if receipt == 200: cond = dtype in ['fp16', 'bf16'] cond &= mode == "group" if not cond: continue + # aiter::mha_fwd_splitkv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen @@ -807,27 
+846,27 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME file_path.write_text(api_pool.api) -def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) with file_path.open('a') as f: - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py new file mode 100644 index 0000000000..650ebaf80e --- /dev/null 
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple

from codegen.cmake_config import *
from codegen.cpp_symbol_map import *


# Bit width of each supported element type; used to derive load vector widths.
DTYPE_BITS = {
    "fp32": 32,
    "fp16": 16,
    "bf16": 16,
    "fp8" : 8,
    "bf8" : 8
}

# k0max (qk-gemm unroll upper bound) -> per-iteration sub-tile used by the
# hdim_q/hdim_v divisibility checks in FmhaFwdApiTrait.dcheck/dvcheck.
K0_MAX_SUBMAX_MAP = {
    32 : 32,
    64 : 64,
    96 : 128,
    128: 128,
    256: 256
}

# pipeline tag -> C++ pipeline class emitted into the generated source
FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
    "qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
}

FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""

# NOTE(review): the angle-bracket template-argument spans of this C++ template
# were destroyed in the garbled source; they have been reconstructed from the
# visible fragments and the non-pagedkv fmha_fwd codegen pattern -- verify
# each reconstructed <...> against the sibling fmha_fwd.py before shipping.
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;

using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
                                                  ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
                                                  ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
                                                  ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
                                                  ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
                                                  {F_vlayout}>;

using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
                                                    {F_skpad},
                                                    {F_dpad},
                                                    {F_dvpad},
                                                    {F_logits},
                                                    {F_bias},
                                                    false,
                                                    {F_lse}, //lse
                                                    {F_pagedkv}, //pagedkv
                                                    {F_squant},
                                                    {F_occupancy},
                                                    {F_skip}>;

using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;

using fmha_mask_{F_idx} = {F_mask};

using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem<
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
    typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    fmha_shape_{F_idx},
    {F_mode},
    fmha_variant_{F_idx},
    fmha_mask_{F_idx},
    fmha_trait_{F_idx}>;

using fmha_pipeline_{F_idx} = {F_pipeline}<
    fmha_pipeline_problem_{F_idx}>;

using fmha_epilogue_{F_idx} =
    ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
                               typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
                               {F_spad}, {F_dvpad}>>;

using fmha_kernel_{F_idx} =
    ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;

using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
                    {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;

#include <iostream>

template<>
float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
{{
    using k_ = fmha_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
    constexpr dim3 blocks                  = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""

FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp"

# Top-level generated dispatcher: falls through every per-dtype branch and
# returns -1 when no instantiated kernel matches the runtime traits.
FMHA_FWD_API="""
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{
    float r = -1;
{F_dispatch}
    return r;
}}
"""

FMHA_FWD_API_PER_DTYPE="""    {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
    }}
"""
FMHA_FWD_API_PER_HDIM_CASE="""        {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
        }}
"""
# One dispatch arm of the generated API: every runtime trait is checked before
# the matching fmha_fwd_pagedkv_<trait_> instantiation is invoked.
# NOTE(review): the template argument of the fmha_fwd_pagedkv_ call was lost in
# the garbled source; reconstructed as <trait_> -- verify against fmha_fwd.py.
FMHA_FWD_API_INNER_DISPATCH="""            {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
                ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
                using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
                return fmha_fwd_pagedkv_<trait_>(s, a);
            }}
"""

@dataclass
class FmhaFwdApiTrait:
    """Python mirror of the C++ fmha_fwd_pagedkv_traits_<> template arguments,
    used to emit the runtime dispatch conditions."""
    pipeline_tag : str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim    : str
    dtype   : str  # data type
    mode    : str  # value from MODE_MAP
    bm0     : int  # tile size along q seqlen (block size)
    bn0     : int  # tile size along qk seqlen
    bk0     : int  # tile size along qk gemm unroll
    bn1     : int  # tile size along v head_dim
    bk1     : int  # tile size along kv gemm unroll
    bk0max  : int
    vlayout : str
    logits  : str
    mask    : str
    bias    : str
    lse     : str
    pagedkv : str
    squant  : str
    spad    : str
    skpad   : str
    dpad    : str
    dvpad   : str
    skip    : str

    @property
    def name(self) -> str:
        # BUGFIX: the 7th tile field previously repeated {self.bn0}; it must be
        # {self.bn1} so the encoded name follows the bm0-bn0-bk0-bn1-bk1-bk0max
        # order used everywhere else (e.g. fmha_fwd_pagedkv_traits_<>).
        return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-'+\
        f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'

    @property
    def scheck(self) -> str:
        # C++ expression guarding seqlen_q padding for this instantiation.
        if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.spad == 't' : return 'true' # always support
            else :                return 'true'
        elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
            if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                return f'a.seqlen_q % {self.bm0} == 0'
        else: assert False

    @property
    def skcheck(self) -> str:
        # C++ expression guarding seqlen_k padding for this instantiation.
        if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
        if self.pipeline_tag == 'qr_async':
            if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
            else :                 return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
        elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
            if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                 return f'a.seqlen_k % {self.bn0} == 0'
        else: assert False

    @property
    def dcheck(self) -> str:
        # C++ expression guarding hdim_q padding for this instantiation.
        if self.pipeline_tag == 'qr_async':
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
            else :               assert False
        elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :               return f'a.hdim_q % {bk0submax} == 0'
        else: assert False

    @property
    def dvcheck(self) -> str:
        # C++ expression guarding hdim_v padding for this instantiation.
        if self.pipeline_tag == 'qr_async':
            vec = int((32 * 4) / DTYPE_BITS[self.dtype])
            if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
            else :                assert False
        elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
            else :                return f'a.hdim_v % {bk0submax} == 0'
        else: assert False

@dataclass
class FmhaFwdPipeline:
    """One pipeline configuration; name encodes every flag into the kernel's
    file/symbol name."""
    tag : str

    F_vlayout : str # row/col
    F_spad    : str # true/false
    F_skpad   : str #
    F_dpad    : str #
    F_dvpad   : str #
    F_logits  : str # t/f
    F_bias    : str # true/false
    F_lse     : str #
    F_pagedkv : str #
    F_squant  : str #
    F_mask    : str # value from MASK_MAP
    F_skip    : str # true/false

    @property
    def name(self) -> str:
        def pad_name() -> str:
            # concatenate one letter-group per enabled padding dimension
            n = ''
            if self.F_spad == 't'  : n += 's'
            if self.F_skpad == 't' : n += 'sk'
            if self.F_dpad == 't'  : n += 'd'
            if self.F_dvpad == 't' : n += 'dv'
            if n != '' : n = 'p' + n
            return n
        pn = pad_name()
        n = f'{self.tag}_v{self.F_vlayout[0]}'
        if pn != '' : n += f'_{pn}'
        else:         n += '_npad'

        if self.F_logits == 't' : n += '_logits'
        else:                     n += '_nlogits'

        if self.F_bias != 'no' : n += f'_{self.F_bias}'
        else:                    n += '_nbias'

        if self.F_mask[0:2] == 's_':
            if self.F_mask == 's_mask': n += f'_mask'
            else:                       n += '_nmask'
        else:
            if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
            else:                    n += '_nmask'

        if self.F_lse == 't'    : n += '_lse'
        else:                     n += '_nlse'

        if self.F_skip == 't'   : n += '_skip'
        else:                     n += '_nskip'

        if self.F_squant == 't' : n += '_squant'
        else:                     n += '_nsquant'

        if self.F_pagedkv == 't' : n += '_pagedkv'
        else:                      n += '_npagedkv'

        return n

class FmhaFwdApiPool:
    """Collects registered kernel traits and renders the dispatch API source."""
    def __init__(self, mask_impl):
        self.pool = dict()
        self.mask_impl = mask_impl

    def register_traits(self, trait : FmhaFwdApiTrait) -> None:
        # TODO: do we need to check duplication?
        if trait.dtype not in self.pool.keys():
            self.pool[trait.dtype] = dict()
        if trait.hdim not in self.pool[trait.dtype].keys():
            self.pool[trait.dtype][trait.hdim] = list()

        self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))

    @property
    def api(self) -> str:
        per_dtypes=str()
        for i, dtype in enumerate(self.pool.keys()):
            per_hdim_case=str()
            for j, hdim in enumerate(self.pool[dtype].keys()):
                traits=self.pool[dtype][hdim]
                inners=str()
                for k, trait in enumerate(traits):
                    if_k = 'if' if k == 0 else 'else if'
                    inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
                                   F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
                                   F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
                                   F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip],
                                   F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
                                   F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
                                   F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
                                   F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
                if_j = 'if' if j == 0 else 'else if'
                # NOTE: uses the bn1 of the last trait in the loop above; all
                # traits registered for one hdim share bn1 in practice.
                per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
            if_i = 'if' if i == 0 else 'else if'
            per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
        if not per_dtypes:
            # empty string we add some ignore to suppress warning in api
            per_dtypes += '    (void)t ; (void)s ; (void)a;'
        return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdTileSize:
    """Static tile/warp partitioning for one generated kernel."""
    F_bm0 : int  # tile size along q seqlen (block size)
    F_bn0 : int  # tile size along k seqlen
    F_bk0 : int  # tile size along qk gemm unroll
    F_bn1 : int  # tile size along v head_dim
    F_bk1 : int  # tile size along kv gemm unroll
    F_bk0max : int  # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
    F_rm0 : int  # number of warps for gemm0 along q seqlen
    F_rn0 : int  # number of warps for gemm0 along k seqlen
    F_rk0 : int  # number of warps for gemm0 along head dim q (not used)
    F_rm1 : int  # number of warps for gemm1 along q seqlen
    F_rn1 : int  # number of warps for gemm1 along head dim v
    F_rk1 : int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0 : int  # gemm0 warp size along m
    F_wn0 : int  # gemm0 warp size along n
    F_wk0 : int  # gemm0 warp size along k
    F_wm1 : int  # gemm1 warp size along m
    F_wn1 : int  # gemm1 warp size along n
    F_wk1 : int  # gemm1 warp size along k
    F_occupancy : int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy

    @property
    def name(self) -> str:
        suffix = "" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
        return (f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
                f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
                f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
                + suffix)

@dataclass
class FmhaFwdKernel:
    """One concrete kernel instance; renders its .cpp source via `template`."""
    F_idx  : int   # counter to differentiate symbols, not a tunable
    F_hdim : int   # head dim
    F_dtype : str  # data type
    F_mode : str   # value from MODE_MAP
    F_tile : FmhaFwdTileSize
    # cross-module classes referenced lazily (string annotations keep the
    # module importable regardless of definition order)
    F_pipeline : "FmhaFwdPipeline"
    mask_impl : str

    @property
    def template(self) -> str:
        return FMHA_FWD_KERNEL_HEADER + \
            FMHA_FWD_KERNEL_BODY.format(
                F_idx           = self.F_idx,
                F_hdim          = self.F_hdim,
                F_dtype         = FWD_DTYPE_MAP[self.F_dtype],
                F_bm0           = self.F_tile.F_bm0,
                F_bn0           = self.F_tile.F_bn0,
                F_bk0           = self.F_tile.F_bk0,
                F_bn1           = self.F_tile.F_bn1,
                F_bk1           = self.F_tile.F_bk1,
                F_bk0max        = self.F_tile.F_bk0max,
                F_rm0           = self.F_tile.F_rm0,
                F_rn0           = self.F_tile.F_rn0,
                F_rk0           = self.F_tile.F_rk0,
                F_rm1           = self.F_tile.F_rm1,
                F_rn1           = self.F_tile.F_rn1,
                F_rk1           = self.F_tile.F_rk1,
                F_wm0           = self.F_tile.F_wm0,
                F_wn0           = self.F_tile.F_wn0,
                F_wk0           = self.F_tile.F_wk0,
                F_wm1           = self.F_tile.F_wm1,
                F_wn1           = self.F_tile.F_wn1,
                F_wk1           = self.F_tile.F_wk1,
                F_vlayout       = LAYOUT_MAP[self.F_pipeline.F_vlayout],
                F_spad          = BOOL_MAP[self.F_pipeline.F_spad],
                F_skpad         = BOOL_MAP[self.F_pipeline.F_skpad],
                F_dpad          = BOOL_MAP[self.F_pipeline.F_dpad],
                F_dvpad         = BOOL_MAP[self.F_pipeline.F_dvpad],
                F_logits        = BOOL_MAP[self.F_pipeline.F_logits],
                F_bias          = BIAS_MAP[self.F_pipeline.F_bias],
                F_lse           = BOOL_MAP[self.F_pipeline.F_lse],
                F_pagedkv       = BOOL_MAP[self.F_pipeline.F_pagedkv],
                F_squant        = BOOL_MAP[self.F_pipeline.F_squant],
                F_skip          = BOOL_MAP[self.F_pipeline.F_skip],
                F_occupancy     = self.F_tile.F_occupancy,
                F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
                F_mask          = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
                F_mode          = MODE_MAP[self.F_mode],
                F_pipeline      = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag])

    @property
    def name(self) -> str:
        # TODO: we don't encode idx here
        return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
            self.F_tile.name + '_' + self.F_pipeline.name

    @property
    def filename(self) -> str:
        return self.name + ".cpp"

    def api_trait(self) -> "FmhaFwdApiTrait":
        """Pack this kernel's settings into the dispatch-table entry."""
        return FmhaFwdApiTrait(
            pipeline_tag=self.F_pipeline.tag,
            hdim=str(self.F_hdim),
            dtype=self.F_dtype,
            mode=self.F_mode,
            bm0=self.F_tile.F_bm0,
            bn0=self.F_tile.F_bn0,
            bk0=self.F_tile.F_bk0,
            bn1=self.F_tile.F_bn1,
            bk1=self.F_tile.F_bk1,
            bk0max=self.F_tile.F_bk0max,
            vlayout=self.F_pipeline.F_vlayout,
            mask=self.F_pipeline.F_mask,
            logits=self.F_pipeline.F_logits,
            bias=self.F_pipeline.F_bias,
            lse=self.F_pipeline.F_lse,
            pagedkv=self.F_pipeline.F_pagedkv,
            squant=self.F_pipeline.F_squant,
            spad=self.F_pipeline.F_spad,
            skpad=self.F_pipeline.F_skpad,
            dpad=self.F_pipeline.F_dpad,
            dvpad=self.F_pipeline.F_dvpad,
            skip=self.F_pipeline.F_skip)

# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
    if dtype == 'fp16' or dtype == 'bf16':
        return {
            # '32'  : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            # '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            ### '96'  : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            # '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
            # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
        }
    elif dtype == 'fp8' or dtype == 'bf8':
        return {
            '64'  : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
            '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
            '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
        }
    else:
        return None

def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> "Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]":
    """Enumerate every kernel to generate and the matching dispatch pool.

    kernel_filter: fnmatch pattern over kernel names ('' means no filter).
    receipt:       integration preset that prunes the combination space.
    optdim_list:   restrict to these head dims; [-1] means all.
    """
    # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
    # support this in future
    def get_pipelines(dtype, hdim) -> "List[FmhaFwdPipeline]":
        # this function will populate a list possible pipelines
        # TODO: the order of List matters! the later in this list will be also be checked later
        # TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!!
        # TODO: how to design this more generic?
        squant = 't' if dtype == 'fp8' else 'f'
        pipelines = []
        if dtype in ['fp16', 'bf16']:
            for logits, mask, bias, pagedkv, skip in itertools.product(
                    ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
                pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
                pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
                pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
                pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
        elif dtype in ['fp8', 'bf8']:
            pass  # TODO
        elif dtype in ['fp8fp16', 'fp8bf16']:
            pass  # TODO
        else:
            assert False
        return pipelines

    gen = list()
    api_pool = FmhaFwdApiPool(mask_impl)

    for dtype in FWD_DTYPE_MAP.keys():
        d = get_fmha_fwd_tile_dict_from_dtype(dtype)
        if d is None:
            continue
        for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
            tile = d[hdim_str]
            hdim = int(hdim_str)
            for pipeline in get_pipelines(dtype, hdim):
                if mode == "group":
                    # in group mode, spad/skpad must be true, since we can't
                    # predict if seqlen of current batch need pad or not
                    if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
                        continue
                if hdim == 192 and tile.F_bn1 == 128:
                    # NOTE: this is used to speedup deepseek prefill case, we don't gen training
                    if pipeline.F_bias != 'no' or pipeline.F_lse == 't':
                        continue
                # logits_soft_cap is only allowed if no bias
                if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
                    continue
                k = FmhaFwdKernel(F_idx=0,
                                  F_hdim=hdim,
                                  F_dtype=dtype,
                                  F_mode=mode,
                                  F_tile=tile,
                                  F_pipeline=pipeline,
                                  mask_impl=mask_impl)
                if kernel_filter != '':
                    if not fnmatch.fnmatch(k.name, kernel_filter):
                        continue
                if optdim_list != [-1]:
                    if hdim not in optdim_list:
                        continue
                # 2 - Flash attention integration
                if receipt in (2, 3):
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_bias in ['no', 'alibi']
                    cond &= pipeline.F_squant == 'f'
                    cond &= pipeline.F_skip == 'f'
                    if not cond:
                        continue
                # PyTorch integration
                elif receipt == 4:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_bias in ['no', 'bias']
                    cond &= pipeline.F_squant == 'f'
                    cond &= pipeline.F_skip == 'f'
                    if not cond:
                        continue
                # Aiter(mha_fwd) integration
                elif receipt == 100:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= mode == 'batch'
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_squant == 'f'
                    if not cond:
                        continue
                # Aiter(mha_varlen_fwd) integration
                elif receipt == 200:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= mode == 'group'
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_squant == 'f'
                    if not cond:
                        continue
                # aiter::mha_fwd C++ api integration
                elif receipt == 600:
                    cond = dtype in ['fp16', 'bf16']
                    cond &= pipeline.F_vlayout == 'row'
                    cond &= pipeline.F_squant == 'f'
                    if not cond:
                        continue

                api_pool.register_traits(k.api_trait())
                gen.append(k)

    return (api_pool, gen)

def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Render one kernel's C++ source into the autogen directory."""
    (autogen_dir / kernel.filename).write_text(kernel.template)

def write_fwd_api(api_pool : "FmhaFwdApiPool", autogen_dir: Path) -> None:
    """Render the dispatch API translation unit."""
    (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
    """Emit every selected paged-kv prefill kernel source plus the dispatch API file."""
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for kernel in kernels:
        write_single_fwd_kernel(kernel, output_dir)
    write_fwd_api(api_pool, output_dir)

def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
    """Append the path of every file write_blobs() would create to ``file_path``."""
    _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    with file_path.open('a') as f:
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
        f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

# ---------------------------------------------------------------------------
# from codegen/utils.py in the original change set (uses `path` = os.path)
# ---------------------------------------------------------------------------
def update_file(file_path, content):
    """Update the file at file_path with the given content if it differs from
    the existing content.

    It avoids unnecessary touching of the file which triggers rebuilds.
    """
    on_disk = ""
    if path.exists(file_path):
        with open(file_path, "r") as fh:
            on_disk = fh.read()
    if on_disk == content:
        return  # identical: leave mtime untouched so nothing rebuilds
    with open(file_path, "w") as fh:
        fh.write(content)
#include "fmha_bwd.hpp" #include "ck_tile/host.hpp" @@ -355,7 +355,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); - assert(slopes.size() == nhead); + assert(slopes.size() == static_cast(nhead)); if(bias.rank_info == 0) { // alibi in 1*h @@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser) if(p_drop > 0) { - p_hp_host_ref.ForEach( - [&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); }); + p_dropped_hp_host_ref = p_hp_host_ref; randval_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]); }); ck_tile::reference_batched_dropout( p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop); - p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType(); } else { - p_hp_host_ref.ForEach([&](auto& self, auto idx) { - p_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + p_lp_host_ref = p_hp_host_ref.template CopyAsType(); } // O = P * V @@ -798,6 +793,14 @@ bool run(const ck_tile::ArgParser& arg_parser) } } + // set to bad values to check if the kernel writes to these buffers + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); + dq_buf.ToDevice(dq_host.data()); + dk_buf.ToDevice(dk_host.data()); + dv_buf.ToDevice(dv_host.data()); + o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); dq_buf.SetZero(); @@ -806,6 +809,20 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::stream_config stream_config_v{ nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")}; + + printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, " + "bias_type=%d, 
has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n", + fmha_traits.hdim_q, + fmha_traits.hdim_v, + fmha_traits.data_type.c_str(), + fmha_traits.is_group_mode, + static_cast(fmha_traits.mask_type), + static_cast(fmha_traits.bias_type), + fmha_traits.has_dbias, + fmha_traits.has_dropout, + fmha_traits.is_store_randval, + fmha_traits.is_deterministic); + fflush(stdout); fmha_bwd(fmha_traits, fmha_args, stream_config_v); dq_buf.FromDevice(dq_host.data()); @@ -854,29 +871,27 @@ bool run(const ck_tile::ArgParser& arg_parser) } // dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i) - ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) { - AccDataType do_dot_o = 0; - for(int o = 0; o < hdim_v; o++) - { - auto idx_gmo = idx_gmn; - idx_gmo[2] = o; - do_dot_o += ck_tile::type_convert(do_host_ref(idx_gmo)) * - ck_tile::type_convert(o_host_refs[wb](idx_gmo)); - } - self(idx_gmn) = ck_tile::type_convert( - p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o)); - }); + ck_tile::make_ParallelTensorFunctor( + [&](auto i0, auto i1, auto i2) { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + do_dot_o += ck_tile::type_convert(do_host_ref(i0, i1, o)) * + ck_tile::type_convert(o_host_refs[wb](i0, i1, o)); + } + ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert( + p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o)); + }, + ds_hp_host_ref.mDesc.get_lengths()[0], + ds_hp_host_ref.mDesc.get_lengths()[1], + ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency()); if(use_dbias) { - ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - dbias_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + dbias_host_ref = ds_hp_host_ref.template CopyAsType(); } - ds_hp_host_ref.ForEach([&](auto& self, auto idx) { - ds_lp_host_ref(idx) = ck_tile::type_convert(self(idx)); - }); + ds_lp_host_ref = ds_hp_host_ref.template CopyAsType(); // dV = P_drop^T@dO^T // dV = P^T@dO^T w/o dropout diff --git 
a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 9179dbd9be..8d35b2d12c 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" @@ -155,6 +156,12 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) { assert(args.nhead_q % args.nhead_k == 0); auto kargs = [&] { + constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0; + const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr; + const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq; + const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq; + const auto batch_stride_dq = dq_uss_acc ? 
args.batch_stride_dq_acc : args.batch_stride_dq; + // create group mode kernel arguments if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode) { @@ -169,7 +176,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.dk_ptr, args.dv_ptr, args.dbias_ptr, - args.dq_acc_ptr, + dq_ptr, args.seqstart_q_ptr, args.seqstart_k_ptr, args.seqlen_k_ptr, @@ -184,7 +191,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.stride_bias, args.stride_randval, args.stride_do, - args.stride_dq_acc, + stride_dq, args.stride_dk, args.stride_dv, args.stride_dbias, @@ -195,7 +202,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, - args.nhead_stride_dq_acc, + nhead_stride_dq, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias, @@ -219,7 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.dk_ptr, args.dv_ptr, args.dbias_ptr, - args.dq_acc_ptr, + dq_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, @@ -233,7 +240,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.stride_bias, args.stride_randval, args.stride_do, - args.stride_dq_acc, + stride_dq, args.stride_dk, args.stride_dv, args.stride_dbias, @@ -244,7 +251,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.nhead_stride_randval, args.nhead_stride_do, args.nhead_stride_lsed, - args.nhead_stride_dq_acc, + nhead_stride_dq, args.nhead_stride_dk, args.nhead_stride_dv, args.nhead_stride_dbias, @@ -255,7 +262,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args) args.batch_stride_randval, args.batch_stride_do, args.batch_stride_lsed, - args.batch_stride_dq_acc, + batch_stride_dq, args.batch_stride_dk, args.batch_stride_dv, args.batch_stride_dbias, @@ -357,31 +364,17 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args) template + bool kIsDeterministic_, + bool kUseTrLoad_, + ck_tile::index_t 
MaxSeqLenQ_> struct fmha_bwd_dq_dk_dv_traits_ { - static constexpr ck_tile::index_t HDim = HDim_; - using DataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = kIsGroupMode_; - static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_; - using FmhaMask = ck_tile::remove_cvref_t; - using FmhaDropout = ck_tile::remove_cvref_t; - static constexpr auto BiasEnum = BiasEnum_; - static constexpr bool kHasBiasGrad = kHasBiasGrad_; - static constexpr bool kPadS = kPadS_; - static constexpr bool kPadSK = kPadSK_; - static constexpr bool kPadD = kPadD_; - static constexpr bool kPadDv = kPadDv_; - static constexpr bool kIsDeterministic = kIsDeterministic_; }; template @@ -392,6 +385,8 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args); template std::string fmha_bwd_dq_dk_dv_get_name_(); +template +int fmha_bwd_dq_dk_dv_maxq_(); template struct fmha_bwd_dot_do_o_traits_ diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b3855e59df..d0f8e3798c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -72,6 +73,7 @@ auto create_args(int argc, char* argv[]) "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" "note when squant=1, this value will be modified by range_q/k") + .insert("logits_soft_cap", "0", "attention logits soft capping value.") .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") .insert("range_v", "16", "per-tensor quantization range of v. 
used if squant=1.") @@ -176,50 +178,30 @@ auto get_elimit(std::string init_method) } } -int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split if(batch_nhead_mblocks >= 0.8f * num_SMs) { return 1; } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + max_splits = std::min({max_splits, num_SMs}); float max_efficiency = 0.f; std::vector efficiency; efficiency.reserve(max_splits); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; - // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, - // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks - // (i.e. it's 11 splits anyway). - // So we check if the number of blocks per split is the same as the previous num_splits. - auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { - return num_splits == 1 || - ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); - }; for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) { - efficiency.push_back(0.f); - } - else - { - float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if(eff > max_efficiency) - { - max_efficiency = eff; - } - efficiency.push_back(eff); + max_efficiency = eff; } + efficiency.push_back(eff); } for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) - { - continue; - } if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) { // printf("num_splits 
chosen = %d\n", num_splits); @@ -232,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int override_num_splits_if_necessary( int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) { + (void)hdim_v; int device; auto status = hipGetDevice(&device); if(status != hipSuccess) @@ -248,15 +231,13 @@ int override_num_splits_if_necessary( // tile size should match the generate.py const int kM0 = 64; - const int kN1 = hdim_v; const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); - const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); if(num_splits < 1 && p_drop == 0.0f) { return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); } return num_splits; @@ -342,7 +323,8 @@ bool run(const ck_tile::ArgParser& arg_parser) } ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \ + CK_TILE_FMHA_FWD_PAGEDKV_API)) if(0 < page_block_size) { std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option" @@ -358,7 +340,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); -#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API +#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API) if(use_cache_batch_idx) { std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option" @@ -416,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale_s == .0f) scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? 
+ const float logits_soft_cap = arg_parser.get_float("logits_soft_cap"); + std::string squant_str = arg_parser.get_str("squant"); bool squant = [&]() { if(squant_str == "auto") @@ -538,13 +522,13 @@ bool run(const ck_tile::ArgParser& arg_parser) max_seqlen_k = real_seqlen_k; } - flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + - static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); + flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + + static_cast(2) * mask.get_unmaskarea() * hdim_v); num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + - sizeof(KDataType) * real_seqlen_k * hdim_q + - sizeof(VDataType) * hdim_v * real_seqlen_k + sizeof(ODataType) * real_seqlen_q * hdim_v); + num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k); } } @@ -564,7 +548,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cerr << "num_splits greater than 128 is not supported" << std::endl; return false; } -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < p_drop && (1 < num_splits || use_kvcache)) { std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" @@ -620,7 +604,7 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor bias_host( bias.type == bias_enum::elementwise_bias - ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor alibi_slope_host( @@ -819,7 +803,7 @@ bool run(const ck_tile::ArgParser& arg_parser) << (is_rotary_interleaved ? 
"inter" : "half") << ")"; } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(1 < num_splits) { std::cout << ", num_splits:" << num_splits; @@ -850,6 +834,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else // fmha_fwd_traits or fmha_splitkv_traits { traits.is_group_mode = (mode == mode_enum::group); + traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; traits.has_lse = lse; @@ -859,6 +844,11 @@ bool run(const ck_tile::ArgParser& arg_parser) { traits.has_dropout = (p_drop > 0.0f); } + else if constexpr(std::is_same_v>) + { + traits.use_pagedkv = use_kvcache; + } } }; @@ -884,7 +874,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else return i_perm ? seqlen_knew : nhead_k * seqlen_knew; }(); - const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_bias = (i_perm ? max_seqlen_k : 1 * max_seqlen_k); const ck_tile::index_t stride_randval = (max_seqlen_k); const ck_tile::index_t stride_o_acc = (hdim_v); const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); @@ -909,7 +899,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return i_perm ? hdim_v * seqlen_knew : seqlen_knew; }(); const ck_tile::index_t nhead_stride_bias = - (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); + (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); @@ -925,7 +915,7 @@ bool run(const ck_tile::ArgParser& arg_parser) (0 < page_block_size ? 
(nhead_k * hdim_v * page_block_size) : (nhead_k * hdim_v * shape_seqlen_k)); const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); - const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); @@ -1007,6 +997,8 @@ bool run(const ck_tile::ArgParser& arg_parser) args.scale_p = scale_p; args.scale_o = scale_o; + args.logits_soft_cap = logits_soft_cap; + args.stride_bias = (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); args.stride_o = stride_o; @@ -1065,6 +1057,17 @@ bool run(const ck_tile::ArgParser& arg_parser) args.split_stride_lse_acc = split_stride_lse_acc; args.split_stride_o_acc = split_stride_o_acc; } + else if constexpr(std::is_same_v>) + { + args.block_table_ptr = + (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); + args.batch_stride_block_table = batch_stride_block_table; + args.page_block_size = page_block_size; + args.is_gappy = false; // use 'false' for flash-attention integration + + args.cache_batch_idx = + (use_cache_batch_idx ? 
cache_batch_idx_buf.GetDeviceBuffer() : nullptr); + } } }; @@ -1086,7 +1089,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const float fwd_ave_time = [&] { #if CK_TILE_FMHA_FWD_SPLITKV_API - if(1 < num_splits || use_kvcache) + if(1 < num_splits && use_kvcache) { fmha_fwd_splitkv_traits fmha_splitkv_traits; init_traits(fmha_splitkv_traits); @@ -1096,6 +1099,18 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); } +#endif +#if CK_TILE_FMHA_FWD_PAGEDKV_API + if(use_kvcache) + { + fmha_fwd_pagedkv_traits fmha_pagedkv_traits; + init_traits(fmha_pagedkv_traits); + + fmha_fwd_pagedkv_args fmha_pagedkv_args; + init_args(fmha_pagedkv_args); + + return fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); + } #endif fmha_fwd_traits fmha_traits; init_traits(fmha_traits); @@ -1120,7 +1135,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec - << " GB/s" << std::flush; + << " GB/s" << std::flush << std::endl; if(do_validation == 0) { @@ -1251,7 +1266,7 @@ bool run(const ck_tile::ArgParser& arg_parser) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < page_block_size) { if(i_perm) { k_host_ref.ForEach([&](auto& self, auto i) { @@ -1302,7 +1317,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API if(0 < page_block_size) { if(is_v_rowmajor) { if(i_perm) { @@ -1375,15 +1390,25 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale_s)); + if(0.f < logits_soft_cap) + { + ck_tile::reference_unary_elementwise( + s_host_ref, s_host_ref, 
[logits_soft_cap](SaccDataType logits) { + return ck_tile::type_convert( + logits_soft_cap * + std::tanhf(ck_tile::type_convert(logits / logits_soft_cap))); + }); + } + if(bias.type == bias_enum::elementwise_bias) { // elementwise bias ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); // clang-format off if(i_perm) - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); else - bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); }); // clang-format on // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 765c221a7b..df1e9e5699 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/host/device_prop.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/fmha.hpp" @@ -143,6 +144,8 @@ struct fmha_fwd_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -167,6 +170,7 @@ struct fmha_fwd_args ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; float p_drop; bool s_randval; @@ -175,6 +179,86 @@ struct fmha_fwd_args drop_seed_offset; }; +struct fmha_fwd_pagedkv_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* lse_ptr; + void* o_ptr; + + void* block_table_ptr; + ck_tile::index_t 
batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr + ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr + bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not + // nullptr. + + const void* cache_batch_idx; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode: seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k + // group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + // + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.seqlen_k_ptr[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // + // when is_gappy=true: + // seqlen_k = kargs.seqlen_k_ptr[b] + // seqstart_k_ptr[b] now store local offset of each batch + // + // when is_gappy=false: + // seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b] + // or kargs.seqlen_k_ptr[b] + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t 
batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + ck_tile::index_t min_seqlen_q; +}; + struct fmha_fwd_splitkv_args { const void* q_ptr; @@ -232,6 +316,8 @@ struct fmha_fwd_splitkv_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -308,6 +394,85 @@ struct fmha_fwd_appendkv_args ck_tile::index_t batch_stride_vnew; }; +struct fmha_batch_prefill_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] - + // 1) + + // kargs.kv_last_page_lens[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] - + // 1) + + // kargs.kv_last_page_lens[b] + const void* seqstart_q_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + // SGLang-style page table + int32_t num_total_pages; + void* kv_indptr; + void* kv_page_indices; +#if 0 // we assume page_block_size=1 for now + void* kv_last_page_lens; + ck_tile::index_t page_block_size; +#endif + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 
1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { @@ -333,6 +498,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -349,6 +515,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.window_size_left, args.window_size_right, args.mask_type, + args.min_seqlen_q, args.p_drop, args.s_randval, args.drop_seed_offset); @@ -371,6 +538,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -414,6 +582,114 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) } } +template +auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + 
args.nhead_q, + args.nhead_q / args.nhead_k, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.is_gappy, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.min_seqlen_q); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + // FmhaKernel::PrintParameters(kargs, args.batch); + if constexpr(FmhaKernel::kIsGroupMode) + { + dim3 grids = FmhaKernel::GridSize( + args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr); + return ck_tile::make_tuple(kargs, grids); + } + else + { + dim3 grids = + FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false); + return ck_tile::make_tuple(kargs, grids); + } +} + template auto 
fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) { @@ -443,6 +719,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.is_gappy, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -485,6 +762,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.cache_batch_idx, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -618,6 +896,117 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) return ck_tile::make_tuple(kargs, grids); } +template +auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, +#if 0 // we assume page_block_size=1 for now + args.kv_last_page_lens, + args.page_block_size, +#endif + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, 
+ args.seqlen_q, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, +#if 0 // we assume page_block_size=1 for now + args.kv_last_page_lens, + args.page_block_size, +#endif + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + // this is used to pattern-match internl kernel implementation, not to instantiate kernel template + bool kPadDv_, + bool kUseTrLoad_, + bool kSkipMinSeqlenQ_ = false> struct fmha_fwd_traits_ { static constexpr ck_tile::index_t HDim = HDim_; @@ -652,6 +1044,7 @@ struct fmha_fwd_traits_ static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; using FmhaMask = ck_tile::remove_cvref_t; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -661,6 +1054,8 @@ struct fmha_fwd_traits_ static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadD = kPadD_; static constexpr bool kPadDv = kPadDv_; + static constexpr bool kUseTrLoad = kUseTrLoad_; + static constexpr bool 
kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; template @@ -677,6 +1072,58 @@ template +struct fmha_fwd_pagedkv_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kIsPagedKV = kIsPagedKV_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; + static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; +}; + +template +float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args); + +template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -776,6 +1224,9 @@ struct fmha_fwd_appendkv_traits_ template float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); +template +float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args); + // This is the public API, will be generated by script struct fmha_fwd_traits { @@ -784,15 +1235,38 @@ struct fmha_fwd_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. 
sync with BlockAttentionBiasEnum bool has_lse; bool has_dropout; bool do_fp8_static_quant; + bool skip_min_seqlen_q = false; // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); +struct fmha_fwd_pagedkv_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + bool has_logits_soft_cap; + mask_enum mask_type; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum + bool has_lse = false; + bool use_pagedkv = true; + bool do_fp8_static_quant = false; + bool skip_min_seqlen_q = false; + // TODO: padding check is inside this api +}; + +float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits&, + fmha_fwd_pagedkv_args&, + const ck_tile::stream_config&); + struct fmha_fwd_splitkv_traits { int hdim_q; @@ -800,6 +1274,7 @@ struct fmha_fwd_splitkv_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. 
sync with BlockAttentionBiasEnum bool has_lse; @@ -821,3 +1296,8 @@ struct fmha_fwd_appendkv_traits float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&); + +using fmha_batch_prefill_traits = fmha_fwd_traits; +float fmha_batch_prefill(fmha_batch_prefill_traits, + fmha_batch_prefill_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 0d35db14d4..0317330511 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -21,8 +21,7 @@ class HandlerId(IntEnum): ops = [] for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) - if full_module_name not in sys.modules: - ops.append(importer.find_spec(module_name).loader.load_module(module_name)) + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) unwanted_prefix = 'fmha_' handlers = dict( [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, @@ -30,7 +29,7 @@ handlers = dict( ) assert 0 < len(handlers) -def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: @@ -40,10 +39,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] - handler(output_dir, kernel_filter, receipt, mask_impl) + handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> 
None: +def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) @@ -52,7 +51,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] - handler(file_path, kernel_filter, receipt, mask_impl) + handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -109,16 +108,25 @@ if __name__ == "__main__": " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ - " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration" + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" + ) + parser.add_argument( + "--optdim", + default='-1', + required=False, + help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ + "eg. 
--optdim=32,64,128,256" ) args = parser.parse_args() api_list = args.direction.split(',') filter_list = args.filter.split(',') filter_list.extend([''] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(',')] if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp old mode 100644 new mode 100755 index c77b700b16..b96482f535 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -21,6 +21,8 @@ enum class mask_enum struct mask_info { mask_enum type; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; ck_tile::index_t y, x; ck_tile::index_t left, right; // FA style SWA left/right @@ -42,6 +44,8 @@ struct mask_info ck_tile::index_t x_total = seqlen_k; ck_tile::index_t y_total = seqlen_q; mask_info tmp; + tmp.seqlen_q = seqlen_q; + tmp.seqlen_k = seqlen_k; auto found_0 = str.find(':'); if(found_0 != std::string::npos) { @@ -148,7 +152,22 @@ struct mask_info } return tmp; } - + ck_tile::index_t get_unmaskarea() const + { + if(type == mask_enum::no_mask) + return seqlen_q * seqlen_k; + ck_tile::index_t area = 0; + for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y) + { + ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast(0)); + ck_tile::index_t x_end = std::min(i_y + x, seqlen_k); + if(x_end > x_start) + { + area += (x_end - x_start); + } + } + return area; + } friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) { mi.serialize(os); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh 
b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 599c595a75..88c16cceb6 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,14 +18,3 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done - -for perm in 0 1 ; do - -$EXE -prec=fp8 -squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 -$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 - -done \ No newline at end of file diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index b867cd6c07..dc2be933bd 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -42,7 +42,6 @@ run_fp16_bf16_tests() { for prec in "fp16" "bf16" ; do for mode in 1 0 ; do for perm in 0 1 ; do - for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do for bias in "n" "e" "a" ; do @@ -51,16 +50,16 @@ run_fp16_bf16_tests() { 
for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS 
- $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 
-num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done diff --git a/example/ck_tile/02_layernorm2d/CMakeLists.txt b/example/ck_tile/02_layernorm2d/CMakeLists.txt index fa69ac0f7a..07714f0fe2 100644 --- a/example/ck_tile/02_layernorm2d/CMakeLists.txt +++ b/example/ck_tile/02_layernorm2d/CMakeLists.txt @@ -25,7 +25,7 @@ add_custom_command( set(EXAMPLE_LAYERNORM2D_FWD "tile_example_layernorm2d_fwd") -message("adding example ${EXAMPLE_LAYERNORM2D_FWD}") +message(DEBUG "adding example ${EXAMPLE_LAYERNORM2D_FWD}") 
add_executable(${EXAMPLE_LAYERNORM2D_FWD} EXCLUDE_FROM_ALL layernorm2d_fwd.cpp) target_include_directories(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS}) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 0238a125dc..d77582630a 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; - static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; - static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); static constexpr ck_tile::index_t total_warps = - (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); // num of warps along m static constexpr ck_tile::index_t BlockWarps_M = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); - return total_warps * (warpSize / ThreadPerBlock_N_); + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); } else { - // static_assert(warpSize % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N_ / warpSize); + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); } }(); @@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_ static constexpr ck_tile::index_t BlockWarps_N = []() { if constexpr(is_warp_per_row) { - static_assert(warpSize % ThreadPerBlock_N_ == 0); + static_assert(ck_tile::get_warp_size() % 
ThreadPerBlock_N_ == 0); return 1; } else { - static_assert(ThreadPerBlock_N_ % warpSize == 0); - return ThreadPerBlock_N_ / warpSize; + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); } }(); diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index b72485222e..bdd5f2da1b 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -191,8 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << std::flush; diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 30cfee22f6..825cd6e522 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,5 +1,18 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) -target_compile_options(tile_example_gemm_universal PRIVATE - -mllvm -enable-noalias-to-md-conversion=0 -) +add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) +add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef) +list(APPEND 
EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0") +target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index 4c16f13cef..c9e392dbd5 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -12,26 +12,29 @@ sh ../script/cmake-ck-dev.sh ../ make tile_example_gemm_basic -j # The memory bound pipeline on the gemm calculation make tile_example_gemm_universal -j +# The weight preshuffle pipeline on the gemm calculation +make tile_example_gemm_weight_preshuffle -j ``` This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal` ## example ``` args: - -b batch size (default:1) -m m dimension (default:1024) -n n dimension (default:2048) -k k dimension (default:64) -a_layout Tensor A data layout (default: R) - -b_layout Tensor B data layout (default: R) + -b_layout Tensor B data layout (default: C) -c_layout Tensor C data layout (default: R) -stride_a Tensor A stride (default:0) -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) - -e Absolute error tolerance (default:1e-5) - -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -prec data type. 
fp16/bf16/fp8/bf8/int8 (default:fp16) -warmup number of iterations before benchmark the kernel (default:10) -repeat number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -split_k splitK value (default:1) + -init 0:random, 1:linear, 2:constant (default:1) + -persistent 0:non-persistent, 1:persistent (default:0) ``` diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp old mode 100755 new mode 100644 index 69051423fb..25781a4ae8 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -1,32 +1,25 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include -#include -#include -#include -#include - -#include "ck_tile/host.hpp" #include "gemm_utils.hpp" -template -float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) -{ - // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; + typename DsLayout, + typename CLayout, + bool Persistent, + typename CDEElementWise> +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) - constexpr int kBlockPerCu = 1; +{ + if constexpr(Persistent) + std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." 
<< std::endl; // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 256; @@ -48,61 +41,89 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using CodegenGemmTraits = - ck_tile::TileGemmTraits; + using CodegenGemmTraits = ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + CodegenPipelineProblem::TransposeC, + memory_operation>>; - if(!Kernel::IsSupportedArgument(kargs)) + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + return Run(MemoryOpSet{}); } - - if(s.log_level_ > 0) + else { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + return Run(MemoryOpAtomicAdd{}); } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; } #include "run_gemm_example.inc" template -int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +int run_gemm_example_prec_type(std::string a_layout, + std::string b_layout, + ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -111,13 +132,13 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a { if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + return run_gemm_example_with_layouts( + 
arg_parser, Row{}, Col{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + return run_gemm_example_with_layouts( + arg_parser, Col{}, Col{}, Row{}); } else { @@ -127,25 +148,25 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } else { - if(a_layout == "R" && b_layout == "R") + if(a_layout == "R" && b_layout == "C") { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Row{}, Row{}); + return run_gemm_example_with_layouts( + arg_parser, Row{}, Col{}, Row{}); } - else if(a_layout == "R" && b_layout == "C") + else if(a_layout == "R" && b_layout == "R") { - return run_gemm_example_with_layouts( - argc, argv, Row{}, Col{}, Row{}); + return run_gemm_example_with_layouts( + arg_parser, Row{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "R") { - return run_gemm_example_with_layouts( - argc, argv, Col{}, Row{}, Row{}); + return run_gemm_example_with_layouts( + arg_parser, Col{}, Row{}, Row{}); } else if(a_layout == "C" && b_layout == "C") { - return run_gemm_example_with_layouts( - argc, argv, Col{}, Col{}, Row{}); + return run_gemm_example_with_layouts( + arg_parser, Col{}, Col{}, Row{}); } else { @@ -154,47 +175,67 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a } } -int run_gemm_example(int argc, char* argv[]) +int run_gemm_example(ck_tile::ArgParser& arg_parser) { - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return -1; - std::string data_type = arg_parser.get_str("prec"); std::string a_layout = arg_parser.get_str("a_layout"); std::string b_layout = arg_parser.get_str("b_layout"); if(data_type == "fp16") { - return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + return run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "bf16") { - return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); + return 
run_gemm_example_prec_type(a_layout, b_layout, arg_parser); } else if(data_type == "fp8") { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); } else if(data_type == "bf8") { return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + a_layout, b_layout, arg_parser); + } + else if(data_type == "i8") + { + return run_gemm_example_prec_type( + a_layout, b_layout, arg_parser); } - -#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType - return run_gemm_example_prec_type( - a_layout, b_layout, argc, argv); + if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + { + return run_gemm_example_prec_type( + a_layout, b_layout, arg_parser); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } } -#endif else { throw std::runtime_error("Unsupported data type for this operation !!!"); } } -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + try + { + return !run_gemm_example(arg_parser); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp new file mode 100644 index 0000000000..a4a8039288 --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -0,0 +1,1009 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" + +/** + * @brief Tile partitioner with output offset support. + * + * This partitioner extends the spatially local tile partitioner to support + * split-K reduction by providing workspace output offset calculation. Each K-split + * writes to a separate slice of the workspace: workspace[k_id * M * N]. + */ +template +struct GemmSplitKTilePartitioner + : public ck_tile::GemmSpatiallyLocalTilePartitioner +{ + using Base = ck_tile::GemmSpatiallyLocalTilePartitioner; + + // Inherit constructors and methods + using Base::Base; + using Base::GetLoopNum; + + /** + * @brief Calculate output pointer offset for split-K reduction. + * + * @param kargs Kernel arguments. + * @param k_id Current K-split ID (from blockIdx.z or calculated k_batch). + * @return ck_tile::index_t The offset for this K-split. + */ + template + CK_TILE_HOST_DEVICE static ck_tile::index_t GetOutputOffset(const KernelArgs& kargs, + ck_tile::index_t k_id) noexcept + { + // Each K-split gets its own M*N workspace slice + return (kargs.k_batch > 1) ? (k_id * kargs.M * kargs.N) : 0; + } +}; + +/** + * @brief Extended GEMM host arguments for two-stage split-K implementation + * + * This structure supports the two-stage split-K approach where: + * 1. Stage 1: GEMM writes partial results to workspace memory + * 2. 
Stage 2: Reduction kernel sums workspace results to final output + * + * The base class e_ptr points to workspace, while final_output_ptr points to the actual output + */ +struct GemmSplitKHostArgs : public ck_tile::GemmHostArgs +{ + using BaseArgs = ck_tile::GemmHostArgs; + + CK_TILE_HOST GemmSplitKHostArgs() = default; + CK_TILE_HOST GemmSplitKHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* workspace_ptr_, // Workspace for partial results + void* e_ptr_, // Final output destination + ck_tile::index_t k_batch_, + ck_tile::index_t M_, + ck_tile::index_t N_, + ck_tile::index_t K_, + ck_tile::index_t stride_A_, + ck_tile::index_t stride_B_, + ck_tile::index_t workspace_stride_, + ck_tile::index_t stride_E_) + : BaseArgs(a_ptr_, + b_ptr_, + workspace_ptr_, // Base e_ptr = workspace_ptr + k_batch_, + M_, + N_, + K_, + stride_A_, + stride_B_, + workspace_stride_), + final_output_ptr(e_ptr_), + final_stride_E(stride_E_) + { + } + + void* final_output_ptr; // Pointer to final output tensor + ck_tile::index_t final_stride_E; // Stride for final output tensor +}; + +/** + * @brief Stage 1: GEMM kernel that writes partial split-K results to workspace + * + * This function performs the matrix multiplication with split-K, where each + * K-split writes its partial result to a separate section of the workspace. + * + * Workspace layout: [k_batch, M, N] where each [M, N] slice contains + * partial results for one K-split. 
+ * + * @param args Extended arguments containing workspace and final output pointers + * @param s Stream configuration for kernel execution + * @return Execution time in milliseconds + */ +template +float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = GemmSplitKTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + // Create base GEMM arguments pointing to workspace instead of final output + // The workspace will store partial results from each K-split + ck_tile::GemmHostArgs base_args(args.a_ptr, + args.b_ptr, + args.e_ptr, + args.k_batch, + args.M, + args.N, + args.K, + args.stride_A, + args.stride_B, + args.stride_E); + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + 
GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(base_args); + + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + if(s.flush_cache_) + { + std::cout << "Flushing cache..." 
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + // For workspace mode, always use SET operation since each K-split writes to separate memory + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} + +/** + * @brief Stage 2: Reduction kernel that sums partial split-K results to final output + * + * This function reduces the partial results stored in workspace memory by stage 1. + * It sums across the k_batch dimension to produce the final GEMM result. 
+ * + * Workspace layout: [k_batch, M, N] -> Final output: [M, N] + * + * @tparam CDataType Output data type + * @tparam ComputeDataType Computation precision for reduction + * @tparam ELayout Memory layout of output tensor + * @param args Extended arguments containing workspace and output information + * @param s Stream configuration for kernel execution + * @return Execution time in milliseconds + */ +template +float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s) +{ + const ck_tile::index_t reduce_dim_size = args.k_batch; // Number of partial results to reduce + // Calculate output size based on the final output tensor dimensions + const ck_tile::index_t output_size = args.M * args.N; + + // Workspace layout: [k_batch, M, N] where each [M, N] slice has the same layout as final output + // The workspace strides need to account for the layout of the final output tensor + auto workspace_shape = ck_tile::make_tuple(args.k_batch, args.M, args.N); + auto workspace_strides = + ck_tile::make_tuple(args.M * args.N, // k_batch stride: jump to next K split + args.final_stride_E, // stride same as final output stride + 1); + + // Define kept and reduced dimensions + constexpr auto kept_dim = ck_tile::sequence<1, 2>{}; // Keep M, N dimensions + constexpr auto reduce_dims = ck_tile::sequence<0>{}; // Reduce k_batch dimension + + using ReduceOp = ck_tile::ReduceOp::Add; + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<128, 128>; + using WarpTile = ck_tile::sequence<32, 128>; + using ThreadTile = ck_tile::sequence<8, 8>; + + constexpr ck_tile::index_t kBlockSize = 256; + constexpr ck_tile::index_t kBlockPerCu = 1; + + ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) / + BlockTile::at(ck_tile::number<0>{}); + + using Shape = ck_tile::Reduce2dShape; + using Problem = + ck_tile::Reduce2dProblem; + using Kernel = ck_tile::Reduce; + + if(!Kernel::IsSupportedArgument(reduce_dim_size, 
workspace_strides)) + { + throw std::runtime_error("Wrong! Reduction arguments not supported!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Stage 2 - Launching Reduction kernel" << '\n' + << "workspace shape: [" << args.k_batch << ", " << args.M << ", " << args.N << "]" + << '\n' + << "output shape: [" << args.M << ", " << args.N << "]" << '\n' + << "grid size: " << kGridSize << std::endl; + } + + float ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, // LDS size + static_cast(args.e_ptr), // workspace input + static_cast(args.final_output_ptr), // final output + workspace_shape, + workspace_strides, + kept_dim, + reduce_dims)); + + return ave_time; +} + +/** + * @brief Orchestrator for two-stage split-K GEMM implementation + * + * This function coordinates the two-stage approach: + * 1. Stage 1: Execute GEMM with each K-split writing to workspace + * 2. Stage 2: Reduce workspace results to final output (if k_batch > 1) + * + * @param args Extended arguments for two-stage execution + * @param s Stream configuration + * @return Total execution time (GEMM + Reduction) + */ +template +float gemm_splitk_two_stage(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s) +{ + float gemm_time = 0.0f; + float reduce_time = 0.0f; + + if(s.log_level_ > 0) + { + std::cout << "Starting Two-Stage GEMM+SplitK with k_batch=" << args.k_batch << std::endl; + std::cout << "Workspace size: " << args.k_batch << " x " << args.M << " x " << args.N + << " = " << args.k_batch * args.M * args.N * sizeof(CDataType) << " bytes" + << std::endl; + } + + // Stage 1: GEMM to workspace + gemm_time = gemm_stage1(args, s); + + // Synchronize before stage 2 + auto sync_result = hipStreamSynchronize(s.stream_id_); + if(sync_result != hipSuccess) + { + throw std::runtime_error("Stream synchronization failed"); + } + + // Stage 2: Reduction from workspace to final output (if needed) + if(args.k_batch > 1) + { + // Use appropriate 
precision for reduction computations + using ComputeDataType = std::conditional_t< + std::is_same_v, + float, + std::conditional_t, float, CDataType>>; + reduce_time = reduce_stage2(args, s); + } + else + { + // Single K-split: simple copy from workspace to final output + auto copy_result = hipMemcpyAsync(args.final_output_ptr, + args.e_ptr, + args.M * args.N * sizeof(CDataType), + hipMemcpyDeviceToDevice, + s.stream_id_); + if(copy_result != hipSuccess) + { + throw std::runtime_error("Memory copy failed"); + } + } + + if(s.log_level_ > 0) + { + std::cout << "GEMM stage time: " << gemm_time << " ms" << std::endl; + if(args.k_batch > 1) + { + std::cout << "Reduction stage time: " << reduce_time << " ms" << std::endl; + } + std::cout << "Total time: " << gemm_time + reduce_time << " ms" << std::endl; + } + + return gemm_time + reduce_time; +} + +/** + * @brief High-level interface for two-stage split-K GEMM execution + * + * @param a_m_k_dev_buf Input matrix A device buffer + * @param b_k_n_dev_buf Input matrix B device buffer + * @param c_m_n_dev_buf Output matrix C device buffer + * @param M Matrix M dimension + * @param N Matrix N dimension + * @param K Matrix K dimension + * @param stride_A Memory stride for matrix A + * @param stride_B Memory stride for matrix B + * @param stride_C Memory stride for matrix C + * @param kbatch Number of K-splits for split-K execution + * @param n_warmup Number of warmup iterations + * @param n_repeat Number of repeat iterations for benchmarking + * @param persistent Whether to use persistent kernel execution + * @return Average execution time in milliseconds + */ +template +float invoke_gemm_splitk_two_stage(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat, + bool 
persistent) +{ + // Calculate workspace size: kbatch * M * N elements + const ck_tile::index_t workspace_size = kbatch * M * N * sizeof(CDataType); + const ck_tile::index_t workspace_stride = stride_C; // Stride for k_batch dimension + + // Allocate workspace memory + ck_tile::DeviceMem workspace_buf(workspace_size); + workspace_buf.SetZero(); + + // Create extended args for two-stage approach + GemmSplitKHostArgs args{ + a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr + b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr + workspace_buf.GetDeviceBuffer(), // workspace_ptr (used as e_ptr for stage 1) + c_m_n_dev_buf.GetDeviceBuffer(), // final_output_ptr + kbatch, // k_batch + M, + N, + K, // dimensions + stride_A, + stride_B, // input strides + workspace_stride, // workspace stride + stride_C // final output stride + }; + + float ave_time; + ck_tile::stream_config config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}; + + if(persistent) + { + ave_time = gemm_splitk_two_stage(args, config); + } + else + { + ave_time = gemm_splitk_two_stage(args, config); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Two-Stage GEMM+SplitK with M=" << M << " N=" << N << " K=" << K + << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + << " kbatch=" << kbatch << " WorkspaceSize=" << workspace_size << " bytes" + << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name + << " B_Type=" << DataTypeTraits::name + << " C_Type=" << DataTypeTraits::name + << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") + << " Persistent=" << (persistent ? 
"on" : "off") << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; + + return ave_time; +} + +// Two-stage implementation of run_gemm_example_with_layouts +template +int run_gemm_example_with_layouts_two_stage(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using AccDataType = typename GemmTypeConfig::AccDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + bool persistent = arg_parser.get_int("persistent"); + + const bool preshuffle = GemmConfig::Preshuffle; + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + if(init_method == 0) + { + if constexpr(preshuffle) + { + ck_tile::FillUniformDistribution{-.5f, .5f}(a_m_k); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_k_n); + } + else + { + 
ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + } + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + if(!preshuffle && GemmConfig::UseStructuredSparsity) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + static_assert(!GemmConfig::PermuteA, "Not implemented"); + + if constexpr(preshuffle) + { + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); + // shuffled buffer B for device implementation + b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); + } + else + { + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + if constexpr(GemmConfig::PermuteB) + { + permute_tensor_b(b_k_n_dev); + } + permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + if constexpr(GemmConfig::PermuteB) + { + std::cout << "Permute for this DataType is not implemented." 
<< std::endl; + return false; + } + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + std::cout << "Using Workspace Split-K Mode (Two-Stage with Reduction)" << std::endl; + // Use the new two-stage approach + invoke_gemm_splitk_two_stage, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat, + persistent); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? 
"correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + if constexpr(GemmConfig::Preshuffle) + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + // memory on host to store gpu reference result + ck_tile::HostTensor c_m_n_gpu_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + // memory on device to store gpu reference result + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU verification result is: " << (pass ? 
"correct" : "fail") << std::endl; + } + + return pass; +} + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + auto [result, arg_parser] = create_args(argc, argv); + bool preshuffle = GemmConfig::Preshuffle; + + if(preshuffle && std::is_same_v) + { + throw std::runtime_error("Preshuffle is not supported for this int4 datatype!"); + } + + if(preshuffle && a_layout != "R" && b_layout != "C") + { + throw std::runtime_error( + "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); + } + + // Use new two-stage approach for both int4 and other data types + if constexpr(std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices when " + "BPrecType is ck_tile::pk_int4_t!"); + } + } + else + { + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Row{}, Row{}, Row{}); + } + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + return 0; +} + +template