diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index bd597344ea..af36f492ba 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @shumway @vidyasagar-amd +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd @ddembeckAMD +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @shumway @vidyasagar-amd +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd diff --git a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py index 557afe2d84..cc66fdbfe8 100644 --- a/.github/scripts/therock_configure_ci.py +++ b/.github/scripts/therock_configure_ci.py @@ -42,6 +42,24 @@ def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]: file=sys.stderr, ) return None + +GITHUB_WORKFLOWS_CI_PATTERNS = [ + "therock*", +] + +def is_path_workflow_file_related_to_ci(path: str) -> bool: + return any( + fnmatch.fnmatch(path, ".github/workflows/" + pattern) + for pattern in GITHUB_WORKFLOWS_CI_PATTERNS + ) or any( + fnmatch.fnmatch(path, ".github/scripts/" + pattern) + for pattern in GITHUB_WORKFLOWS_CI_PATTERNS + ) + +def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool: + if paths is None: + return False + return any(is_path_workflow_file_related_to_ci(p) for p in paths) # Paths matching any of these patterns are considered to have no influence over # build or test workflows so any related jobs can be skipped if all paths @@ -82,12 +100,16 @@ def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool: ) other_paths = paths_set - github_workflows_paths + related_to_ci = check_for_workflow_file_related_to_ci(github_workflows_paths) contains_other_non_skippable_files = check_for_non_skippable_path(other_paths) print("should_ci_run_given_modified_paths findings:") print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}") - if contains_other_non_skippable_files: + if related_to_ci: + print("Enabling build jobs since a related workflow file was modified") + return True + elif contains_other_non_skippable_files: print("Enabling TheRock CI jobs since a non-skippable path was modified") return True else: diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 7db124d2a1..271c6376ca 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -26,31 +26,49 @@ jobs: AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} TEATIME_FORCE_INTERACTIVE: 0 AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini + CACHE_DIR: ${{ github.workspace }}/.container-cache + # The ccache.conf will be written by setup_ccache.py before this gets used. + CCACHE_CONFIGPATH: ${{ github.workspace }}/.ccache/ccache.conf steps: + - name: "Checking out repository for rocm-libraries" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/rocm-libraries" + - name: Checkout composable_kernel repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: "composable_kernel" - name: Checkout TheRock repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57 + ref: dc05d637054ad197c84b00e24b6262af0ec797c6 # 10-03-2025 commit path: "TheRock" + - name: Setup ccache + run: | + ./TheRock/build_tools/setup_ccache.py \ + --config-preset "github-oss-presubmit" \ + --dir "$(dirname $CCACHE_CONFIGPATH)" \ + --local-path "$CACHE_DIR/ccache" + echo "namespace = ext_composable_kernel" >> $CCACHE_CONFIGPATH + echo "[*] ccache_config contents:" + cat $CCACHE_CONFIGPATH + - name: Runner Health Settings run: | - df -h - cmake --version - echo "Installed Python versions:" - ls -d /opt/python - echo "python: $(which python), python3: $(which python3)" - echo "Git version: $(git --version)" - git config --global --add safe.directory $PWD - git config fetch.parallel 10 - + ./TheRock/build_tools/health_status.py + - name: Fetch sources run: | - ./TheRock/build_tools/fetch_sources.py --jobs 12 + ./TheRock/build_tools/fetch_sources.py --jobs 12 --no-include-rocm-libraries --no-include-ml-frameworks + + - name: Patch rocm-libraries + run: | + git config --global --add safe.directory '*' + git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps run: | @@ -84,6 +102,10 @@ jobs: echo "Artifacts:" echo "----------" du -h -d 1 TheRock/build/artifacts + echo "CCache Stats:" + echo "-------------" + ccache -s -v + tail -v -n +1 .ccache/compiler_check_cache/* > TheRock/build/logs/ccache_compiler_check_cache.log - name: Configure AWS Credentials for non-forked repos if: ${{ always() && !github.event.pull_request.head.repo.fork }} @@ -92,32 +114,14 @@ jobs: aws-region: us-east-2 role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external - - name: Create Logs index Files and upload logs + - name: Post Build Upload if: always() run: | - python3 TheRock/build_tools/github_actions/create_log_index.py \ - --build-dir=TheRock/build \ - --amdgpu-family=${{ env.AMDGPU_FAMILIES }} - - python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \ - --build-dir=TheRock/build \ - --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} - - - name: Upload artifacts - run: | - python TheRock/build_tools/github_actions/upload_build_artifacts.py \ + python3 TheRock/build_tools/github_actions/post_build_upload.py \ --run-id ${{ github.run_id }} \ --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ - --build-dir TheRock/build - - - name: Add Links to Job Summary - if: always() - run: | - python TheRock/build_tools/github_actions/upload_build_summary.py \ - --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ - --build-dir TheRock/build + --build-dir TheRock/build \ + --upload therock-test-linux: name: "Test" diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml index 3232652b6b..40a3b0bec8 100644 --- a/.github/workflows/therock-ci.yml +++ b/.github/workflows/therock-ci.yml @@ -56,7 +56,14 @@ jobs: uses: ./.github/workflows/therock-ci-linux.yml secrets: inherit with: - cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../" + cmake_options: >- + -DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON + -DTHEROCK_ENABLE_MIOPEN=ON + -DTHEROCK_ENABLE_ALL=OFF + -DTHEROCK_USE_EXTERNAL_COMPOSABLE_KERNEL=ON + -DTHEROCK_COMPOSABLE_KERNEL_SOURCE_DIR=../composable_kernel + -DTHEROCK_USE_EXTERNAL_ROCM_LIBRARIES=ON + -DTHEROCK_ROCM_LIBRARIES_SOURCE_DIR=../ amdgpu_families: "gfx94X-dcgpu" test_runs_on: "linux-mi325-1gpu-ossci-rocm" diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml new file mode 100644 index 0000000000..068dbe3033 --- /dev/null +++ b/.github/workflows/therock-test-component.yml @@ -0,0 +1,71 @@ +name: Test component + +on: + workflow_call: + inputs: + artifact_run_id: + type: string + default: "" + amdgpu_families: + type: string + test_runs_on: + type: string + platform: + type: string + component: + type: string + + +permissions: + contents: read + +jobs: + test_component: + name: 'Test ${{ fromJSON(inputs.component).job_name }} (shard ${{ matrix.shard }} of ${{ fromJSON(inputs.component).total_shards }})' + runs-on: ${{ inputs.test_runs_on }} + container: + image: ${{ inputs.platform == 'linux' && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:4150afe4759d14822f0e3f8930e1124f26e11f68b5c7b91ec9a02b20b1ebbb98' || null }} + options: --ipc host + --group-add video + --device /dev/kfd + --device /dev/dri + --group-add 110 + --env-file /etc/podinfo/gha-gpu-isolation-settings + strategy: + fail-fast: false + matrix: + # The shard array is based on "total_shards" from "fetch_test_configurations.py" + # The test executable will shard based on the array. (ex: [1, 2, 3, 4] = four test shards) + shard: ${{ fromJSON(inputs.component).shard_arr }} + defaults: + run: + shell: bash + env: + VENV_DIR: ${{ github.workspace }}/.venv + ARTIFACT_RUN_ID: "${{ inputs.artifact_run_id != '' && inputs.artifact_run_id || github.run_id }}" + OUTPUT_ARTIFACTS_DIR: "./build" + THEROCK_BIN_DIR: "./build/bin" + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + steps: + - name: Checkout Repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + repository: "ROCm/TheRock" + + - name: Run setup test environment workflow + uses: './.github/actions/setup_test_environment' + with: + ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} + VENV_DIR: ${{ env.VENV_DIR }} + FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }} + IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} + + - name: Test + timeout-minutes: ${{ fromJSON(inputs.component).timeout_minutes }} + env: + SHARD_INDEX: ${{ matrix.shard }} + TOTAL_SHARDS: ${{ fromJSON(inputs.component).total_shards }} + run: | + ${{ fromJSON(inputs.component).test_script }} diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index 37ddd399ad..54e068eb3d 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -37,41 +37,17 @@ jobs: test_components: name: 'Test ${{ matrix.components.job_name }}' - runs-on: ${{ inputs.test_runs_on }} - needs: configure_test_matrix + needs: [configure_test_matrix] # skip tests if no test matrix to run if: ${{ needs.configure_test_matrix.outputs.components != '[]' }} strategy: fail-fast: false matrix: components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }} - defaults: - run: - shell: bash - env: - VENV_DIR: ${{ github.workspace }}/.venv - ARTIFACT_RUN_ID: "${{ github.run_id }}" - OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build - THEROCK_BIN_DIR: "./build/bin" - steps: - - name: Checkout Repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: "ROCm/TheRock" - - - name: Run setup test environment workflow - uses: './.github/actions/setup_test_environment' - with: - ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} - AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} - OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} - VENV_DIR: ${{ env.VENV_DIR }} - FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }} - PLATFORM: ${{ inputs.platform }} - IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} - - - name: Test - timeout-minutes: ${{ matrix.components.timeout_minutes }} - run: | - if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . ${VENV_DIR}/Scripts/activate ; fi - ${{ matrix.components.test_script }} + uses: './.github/workflows/therock-test-component.yml' + with: + artifact_run_id: ${{ github.run_id }} + amdgpu_families: ${{ inputs.amdgpu_families }} + test_runs_on: ${{ inputs.test_runs_on }} + platform: ${{ inputs.platform }} + component: ${{ toJSON(matrix.components) }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 664c5219e2..2d936d3a48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,12 +6,12 @@ repos: entry: clang-format-18 -i --style=file language: system types_or: [c++, inc] - - id: copyright-year-checker - name: copyright-year-checker - entry: script/check_copyright_year.sh - verbose: false - language: script - types: [c++] + # - id: copyright-year-checker + # name: copyright-year-checker + # entry: script/check_copyright_year.sh + # verbose: false + # language: script + # types: [c++] - id: remove-exec-bit name: Remove executable bit from non-executable files entry: script/remove_exec_bit.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index 38669385f3..9de78f3043 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,8 +5,12 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added +* Added a compute async pipeline in the CK TILE universal GEMM on gfx950 +* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. +* Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. +* Added support for grouped_gemm kernels to perform multi_d elementwise operation. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). @@ -15,6 +19,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM * Added support for Multiple D GEMM +* Added support for Multiple ABD GEMM * Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. @@ -29,6 +34,10 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added benchmarking support for tile engine GEMM Multi D. * Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. * Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. +* Added support for f32 to FMHA (fwd/bwd). +* Added tensor-wise quantization for CK_TILE GEMM. +* Added support for batched contraction kernel. +* Added pooling kernel in CK_TILE ### Optimized diff --git a/CMakeLists.txt b/CMakeLists.txt index ddadfb0353..f4d3a83c34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -220,7 +220,10 @@ rocm_check_target_ids(SUPPORTED_GPU_TARGETS message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") -if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") +# Cache SUPPORTED_GPU_TARGETS for debug +set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets") + +if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") message(STATUS "Enabling XDL instances") add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") @@ -339,6 +342,7 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) +option(ENABLE_JSON_DUMP "Whether to enable json dump for examples." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -352,6 +356,11 @@ if(ENABLE_ASM_DUMP) message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() +if (ENABLE_JSON_DUMP) + add_compile_definitions(CK_ENABLE_JSON_DUMP) + message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/Dockerfile b/Dockerfile index 6f5cd0115d..07327442fe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,27 +1,23 @@ + FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=6.4.1 +ARG ROCMVERSION=7.0.1 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn +ENV DEBIAN_FRONTEND=noninteractive # Add rocm repository RUN set -xe && \ - apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ - curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg + apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl -RUN if [ "$ROCMVERSION" != "6.5" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60401-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60401-1_all.deb && \ - wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \ - fi - -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu jammy main universe | tee -a /etc/apt/sources.list" && \ - amdgpu-install -y --usecase=rocm --no-dkms +RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \ + apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \ + apt update && \ + apt install python3-setuptools python3-wheel -y && \ + apt install rocm-dev -y ## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache @@ -45,7 +41,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libelf-dev \ libnuma-dev \ libpthread-stubs0-dev \ - llvm-amdgpu \ mpich \ net-tools \ pkg-config \ @@ -61,17 +56,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- zip \ libzstd-dev \ openssh-server \ - clang-format-12 \ clang-format-18 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ rm -rf amdgpu-install* && \ -# Remove unnecessary rocm components that take a lot of space - apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt - #Install latest ccache -RUN git clone https://github.com/ccache/ccache.git && \ + git clone https://github.com/ccache/ccache.git && \ cd ccache && mkdir build && cd build && cmake .. && make install && \ #Install ninja build tracing tools cd / && \ diff --git a/Dockerfile.aiter b/Dockerfile.aiter index 245e39fb75..b61c1e41a5 100644 --- a/Dockerfile.aiter +++ b/Dockerfile.aiter @@ -1,10 +1,8 @@ -ARG BASE_DOCKER="rocm/pytorch:latest" +ARG BASE_DOCKER="rocm/composable_kernel-private:ck_aiter_base" FROM $BASE_DOCKER ARG AITER_BRANCH="main" ARG CK_AITER_BRANCH="develop" -RUN groupadd -g 109 render && \ - usermod -u 1001 jenkins && \ - groupmod -g 1001 jenkins && \ +RUN groupadd irc && \ pip install pandas zmq einops && \ pip install numpy==1.26.2 && \ sudo mkdir /home/jenkins && \ diff --git a/Dockerfile.compiler b/Dockerfile.compiler index 0306057e45..47bd8294b6 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4.1" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index 2b9ea200f0..11a9d9eb74 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -46,6 +46,58 @@ def runShell(String command){ return (output != "") } +def shouldRunCICheck() { + // Define patterns for files that should not trigger CI + def skipFilePatterns = [ + /^\.github\/.*/, // GitHub workflow files + /^docs\/.*/, // Documentation files + /^LICENSE$/, // License file + /^.*\.gitignore$/, // Git ignore files + /.*\.md$/ // Markdown files + ] + + try { + // Get the list of changed files + def changedFiles = sh( + returnStdout: true, + script: ''' + if [ "$CHANGE_ID" != "" ]; then + # For PR builds, compare against target branch + git diff --name-only origin/$CHANGE_TARGET...HEAD + else + # For regular builds, compare against previous commit + git diff --name-only HEAD~1..HEAD + fi + ''' + ).trim().split('\n') + + if (changedFiles.isEmpty() || (changedFiles.size() == 1 && changedFiles[0].trim().isEmpty())) { + echo "No changed files detected - this might be a manual trigger or merge commit, running CI for safety" + return true + } + + echo "Changed files: ${changedFiles.join(', ')}" + + // Check if any changed files are not in the skip patterns + def hasFilesRequiringCI = changedFiles.any { file -> + !skipFilePatterns.any { pattern -> + file ==~ pattern + } + } + + if (hasFilesRequiringCI) { + echo "Found files that require CI" + return true + } else { + echo "Only non-relevant files changed, skipping CI" + return false + } + } catch (Exception e) { + echo "Error checking changed files: ${e.getMessage()}, running CI by default" + return true + } +} + def getBaseDockerImageName(){ def img if (params.USE_CUSTOM_DOCKER != ""){ @@ -53,7 +105,7 @@ def getBaseDockerImageName(){ } else{ def ROCM_numeric = parseVersion("${params.ROCMVERSION}") - if ( ROCM_numeric.major <= 6 && ROCM_numeric.minor < 5 ){ + if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){ img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ @@ -149,7 +201,7 @@ def getDockerImage(Map conf=[:]){ image = conf.get("docker_name", "") echo "Using legacy docker: ${image}" } - else if ( params.BUILD_GFX950 && conf.get("docker_name", "") != "" ){ + else if ( (params.BUILD_GFX950 || params.RUN_CK_TILE_FMHA_TESTS) && conf.get("docker_name", "") != "" ){ image = conf.get("docker_name", "") echo "Using special docker: ${image}" } @@ -157,9 +209,9 @@ def getDockerImage(Map conf=[:]){ image = getDockerImageName() echo "Using default docker: ${image}" } - //Check if image exists + //Check if image exists def retimage - try + try { echo "Pulling image: ${image}" retimage = docker.image("${image}") @@ -186,11 +238,11 @@ def buildDocker(install_prefix){ dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f Dockerfile.compiler . " } else if(params.RUN_AITER_TESTS){ - image_name = "rocm/composable_kernel:ck_aiter" + image_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter" dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter --build-arg AITER_BRANCH='${params.aiter_branch}' --build-arg CK_AITER_BRANCH='${params.ck_aiter_branch}' . " } else if(params.RUN_PYTORCH_TESTS){ - image_name = "rocm/composable_kernel:ck_pytorch" + image_name = "${env.CK_DOCKERHUB}:ck_pytorch" dockerArgs = dockerArgs + " --no-cache -f Dockerfile.pytorch --build-arg CK_PYTORCH_BRANCH='${params.ck_pytorch_branch}' . " } else{ @@ -232,7 +284,7 @@ def cmake_build(Map conf=[:]){ def setup_args = conf.get("setup_args","") // make sure all unit tests always run on develop branch def runAllUnitTests = (env.BRANCH_NAME == "develop") ? true : params.RUN_ALL_UNIT_TESTS - + if (prefixpath != "/usr/local"){ setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} " } @@ -357,7 +409,7 @@ def cmake_build(Map conf=[:]){ "build_cmd", "${build_envs} ninja -j${nt} ${config_targets}" ) - + cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} @@ -449,7 +501,7 @@ def buildHipClangJob(Map conf=[:]){ checkout scm def prefixpath = conf.get("prefixpath", "/opt/rocm") - // Jenkins is complaining about the render group + // Jenkins is complaining about the render group def dockerOpts if ( params.BUILD_INSTANCES_ONLY ){ dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" @@ -476,7 +528,7 @@ def buildHipClangJob(Map conf=[:]){ def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 20, unit: 'HOURS') { @@ -515,7 +567,7 @@ def Build_CK(Map conf=[:]){ checkout scm def prefixpath = conf.get("prefixpath", "/opt/rocm") - // Jenkins is complaining about the render group + // Jenkins is complaining about the render group def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " @@ -538,7 +590,7 @@ def Build_CK(Map conf=[:]){ def image def retimage - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -716,10 +768,10 @@ def process_results(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm //use older image that has user jenkins - def image = "rocm/composable_kernel:ck_ub22.04_rocm6.3" + def image = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm6.3" def prefixpath = "/opt/rocm" - // Jenkins is complaining about the render group + // Jenkins is complaining about the render group def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " @@ -728,7 +780,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -827,7 +879,7 @@ def run_aiter_tests(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm //use the latest pytorch image - def image = "rocm/composable_kernel:ck_aiter" + def image = "${env.CK_DOCKERHUB_PRIVATE}:ck_aiter" def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins" def variant = env.STAGE_NAME def retimage @@ -836,7 +888,7 @@ def run_aiter_tests(Map conf=[:]){ dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " echo "Docker flags: ${dockerOpts}" - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -852,13 +904,14 @@ def run_aiter_tests(Map conf=[:]){ } withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 2, unit: 'HOURS'){ + timeout(time: 5, unit: 'HOURS'){ try{ sh "rocminfo" sh "python3 --version" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py" @@ -885,7 +938,7 @@ def run_pytorch_tests(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm //use the latest pytorch-nightly image - def image = "rocm/composable_kernel:ck_pytorch" + def image = "${env.CK_DOCKERHUB}:ck_pytorch" def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins" def variant = env.STAGE_NAME def retimage @@ -894,7 +947,7 @@ def run_pytorch_tests(Map conf=[:]){ dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " echo "Docker flags: ${dockerOpts}" - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -930,13 +983,14 @@ def run_pytorch_tests(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true - 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true - 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false - 0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX10=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false''' : "" +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true + 0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true + 0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true + 0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX10=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : "" pipeline { agent none @@ -956,20 +1010,20 @@ pipeline { defaultValue: '', description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( - name: 'ROCMVERSION', - defaultValue: '6.4.1', - description: 'Specify which ROCM version to use: 6.4.1 (default).') + name: 'ROCMVERSION', + defaultValue: '7.0.1', + description: 'Specify which ROCM version to use: 7.0.1 (default).') string( - name: 'COMPILER_VERSION', - defaultValue: '', + name: 'COMPILER_VERSION', + defaultValue: '', description: 'Specify which version of compiler to use: release, amd-staging, amd-mainline, or leave blank (default).') string( - name: 'COMPILER_COMMIT', - defaultValue: '', + name: 'COMPILER_COMMIT', + defaultValue: '', description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit (default), or use some specific commit of llvm-project branch.') string( - name: 'BUILD_COMPILER', - defaultValue: '/opt/rocm/llvm/bin/clang++', + name: 'BUILD_COMPILER', + defaultValue: '/opt/rocm/llvm/bin/clang++', description: 'Build CK with /opt/rocm/bin/hipcc, /llvm-project/build/bin/clang++, or with /opt/rocm/llvm/bin/clang++ (default).') booleanParam( name: "RUN_FULL_QA", @@ -1033,12 +1087,12 @@ pipeline { description: "Build CK and run tests on gfx90a (default: ON)") booleanParam( name: "BUILD_GFX942", - defaultValue: false, - description: "Build CK and run tests on gfx942 (default: OFF)") + defaultValue: true, + description: "Build CK and run tests on gfx942 (default: ON)") booleanParam( name: "BUILD_GFX950", - defaultValue: false, - description: "Build CK and run tests on gfx950 (default: OFF)") + defaultValue: true, + description: "Build CK and run tests on gfx950 (default: ON)") booleanParam( name: "BUILD_GFX10", defaultValue: true, @@ -1091,6 +1145,10 @@ pipeline { name: 'ck_aiter_branch', defaultValue: 'develop', description: 'Specify which branch of CK to test with AITER (default: develop)') + booleanParam( + name: "FORCE_CI", + defaultValue: false, + description: "Force CI to run even when only non-relevant files are changed (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -1104,7 +1162,20 @@ pipeline { DOCKER_BUILDKIT = "1" } stages{ + stage("Determine CI Execution") { + agent{ label rocmnode("nogpu") } + steps { + script { + env.SHOULD_RUN_CI = String.valueOf(params.FORCE_CI.toBoolean() || shouldRunCICheck()) + echo "SHOULD_RUN_CI: ${env.SHOULD_RUN_CI}" + } + } + } stage("Build Docker"){ + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel{ stage('Docker /opt/rocm'){ agent{ label rocmnode("nogpu") } @@ -1116,6 +1187,11 @@ pipeline { } } stage("Static checks") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + expression { params.RUN_CPPCHECK.toBoolean() } + } parallel{ stage('Clang Format and Cppcheck') { when { @@ -1125,16 +1201,16 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ - -o -not -path \'*.git*\' -iname \'*.hpp\' \ - -o -not -path \'*.git*\' -iname \'*.cpp\' \ - -o -iname \'*.h.in\' \ - -o -iname \'*.hpp.in\' \ - -o -iname \'*.cpp.in\' \ - -o -iname \'*.cl\' \ + execute_cmd = "(cd .. && git ls-files \'*.h\' \ + \'*.hpp\' \ + \'*.cpp\' \ + \'*.h.in\' \ + \'*.hpp.in\' \ + \'*.cpp.in\' \ + \'*.cl\' \ | grep -v 'build/' \ | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\') && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ @@ -1155,16 +1231,17 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ - -o -not -path \'*.git*\' -iname \'*.hpp\' \ - -o -not -path \'*.git*\' -iname \'*.cpp\' \ - -o -iname \'*.h.in\' \ - -o -iname \'*.hpp.in\' \ - -o -iname \'*.cpp.in\' \ - -o -iname \'*.cl\' \ + execute_cmd = "(cd .. && git ls-files \ + \'*.h\' \ + \'*.hpp\' \ + \'*.cpp\' \ + \'*.h.in\' \ + \'*.hpp.in\' \ + \'*.cpp.in\' \ + \'*.cl\' \ | grep -v 'build/' \ | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\')" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) @@ -1175,6 +1252,10 @@ pipeline { } stage("Run Pytorch Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run Pytorch Tests on gfx942") @@ -1193,6 +1274,10 @@ pipeline { } stage("Run AITER Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run AITER Tests on gfx942") @@ -1207,10 +1292,26 @@ pipeline { cleanWs() } } + stage("Run AITER Tests on gfx950") + { + when { + beforeAgent true + expression { params.RUN_AITER_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx950")} + steps{ + run_aiter_tests() + cleanWs() + } + } } } stage("Run Grouped Conv Large Case Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run Grouped Conv Large Case Tests on gfx90a") @@ -1235,6 +1336,10 @@ pipeline { } stage("Run Comprehensive Convolution Dataset Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run Comprehensive Dataset Tests on gfx90a") @@ -1267,6 +1372,10 @@ pipeline { } stage("Run Codegen Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run Codegen Tests on gfx90a") @@ -1278,7 +1387,7 @@ pipeline { agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" - execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \ + execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_PREFIX_PATH=/opt/rocm ../codegen && \ make -j64 check""" } steps{ @@ -1290,6 +1399,10 @@ pipeline { } stage("Run CK_TILE_FMHA Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run CK_TILE_FMHA Tests on gfx90a") @@ -1321,7 +1434,7 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ - make -j64 tile_example_fmha_fwd tile_example_fmha_bwd && \ + make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \ cd ../ && example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ } @@ -1330,10 +1443,33 @@ pipeline { cleanWs() } } + stage("Run CK_TILE_FMHA Tests on gfx950") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_FMHA_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx950") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx950 && \ + make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \ + cd ../ && + example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx950 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } } } stage("Run TILE_ENGINE_GEMM Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run TILE_ENGINE_GEMM Tests on gfx90a") @@ -1353,10 +1489,15 @@ pipeline { -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ + -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8" \ + -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ --warmup 5 --repeat 5 --verbose --json results.json && \ + ninja -j64 benchmark_gemm_preshuffle_all && \ + python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \ + --warmup 5 --repeat 5 --verbose --json results.json && \ ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ @@ -1388,10 +1529,15 @@ pipeline { -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -D GEMM_MULTI_D_DATATYPE="fp16" \ -D GEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \ + -D GEMM_PRESHUFFLE_DATATYPE="fp16;fp8" \ + -D GEMM_PRESHUFFLE_LAYOUT="rcr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ ninja -j64 benchmark_gemm_all && \ python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ --warmup 5 --repeat 5 --verbose --json results.json && \ + ninja -j64 benchmark_gemm_preshuffle_all && \ + python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \ + --warmup 5 --repeat 5 --verbose --json results.json && \ ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \ ./bin/benchmark_gemm_multi_d_fp16_rrrr && \ ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \ @@ -1406,11 +1552,45 @@ pipeline { cleanWs() } } + stage("Run TILE_ENGINE_GEMM Tests on gfx1201") + { + when { + beforeAgent true + expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx1201") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx1201" \ + -D GEMM_DATATYPE="fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -DGEMM_CONFIG_FILE=gfx120x_config.json \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && \ + ninja -j64 benchmark_gemm_all && \ + python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ + --warmup 5 --repeat 5 --verbose --json results.json && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp16_ccr """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } } } stage("Build CK and run Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Build CK with RHEL8") @@ -1494,7 +1674,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } @@ -1549,7 +1729,7 @@ pipeline { agent{ label rocmnode("gfx942") } steps{ script { - def execute_args = params.NINJA_FTIME_TRACE ? + def execute_args = params.NINJA_FTIME_TRACE ? """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ @@ -1558,8 +1738,8 @@ pipeline { -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ - - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0") + + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB}:ck_ub24.04_rocm7.0.1") } cleanWs() } @@ -1585,13 +1765,13 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on gfx1101") + stage("Build CK and run Tests on gfx11") { when { beforeAgent true expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx1101") } + agent{ label 'miopen && (gfx1101 || gfx1100)' } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ @@ -1628,6 +1808,17 @@ pipeline { } } } + post { + success { + script { + // Report the parent stage build ck and run tests status + def variant = env.STAGE_NAME + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { + echo "Reporting success status for build ck and run tests" + } + } + } + } } stage("Process Performance Test Results") { @@ -1645,6 +1836,22 @@ pipeline { } } } + post { + success { + script { + // Report the skipped parent's stage status + def parentVariant = "Process Performance Test Results" + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${parentVariant}", account: 'ROCm', repo: 'composable_kernel') { + echo "Process Performance Test Results stage skipped." + } + // Report the skipped stage's status + def variant = "Process results" + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { + echo "Process Performance Test Results stage skipped." + } + } + } + } } } } diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index f27e557cc3..21f6e652b8 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -48,7 +48,7 @@ else() endif() if (GPU_TARGETS) - if (GPU_TARGETS MATCHES "gfx9") + if (GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 6587f4c4be..41e2fa2cc0 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -12,6 +12,17 @@ FetchContent_Declare( GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 ) +FetchContent_Populate(GTest) + +# Patch googlemock/CMakeLists.txt to fix invalid include path +set(GMOCK_CMAKE "${gtest_SOURCE_DIR}/googlemock/CMakeLists.txt") +file(READ "${GMOCK_CMAKE}" GMOCK_CMAKE_CONTENT) +string(REPLACE [[gtest_SOURCE_DIR}/include]] + [[gtest_SOURCE_DIR}/googletest/include]] + GMOCK_CMAKE_CONTENT + "${GMOCK_CMAKE_CONTENT}") +file(WRITE "${GMOCK_CMAKE}" "${GMOCK_CMAKE_CONTENT}") + # Suppress ROCMChecks WARNING on GoogleTests set(ROCM_DISABLE_CHECKS FALSE) macro(rocm_check_toolchain_var var access value list_file) @@ -24,7 +35,7 @@ if(WIN32) set(gtest_force_shared_crt ON CACHE_INTERNAL "") endif() -set(BUILD_GMOCK OFF CACHE INTERNAL "") +set(BUILD_GMOCK ON CACHE INTERNAL "") set(INSTALL_GTEST OFF CACHE INTERNAL "") # Store the current value of BUILD_SHARED_LIBS @@ -32,15 +43,12 @@ set(__build_shared_libs ${BUILD_SHARED_LIBS}) set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") set(ROCM_DISABLE_CHECKS TRUE) -FetchContent_MakeAvailable(GTest) +add_subdirectory(${gtest_SOURCE_DIR} ${gtest_BINARY_DIR}) set(ROCM_DISABLE_CHECKS FALSE) # Restore the old value of BUILD_SHARED_LIBS set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE) -set(BUILD_GMOCK OFF CACHE INTERNAL "") -set(INSTALL_GTEST OFF CACHE INTERNAL "") - set(GTEST_CXX_FLAGS -Wno-undef -Wno-reserved-identifier @@ -71,3 +79,12 @@ target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0) target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0) +if(TARGET gmock) + target_compile_options(gmock PRIVATE ${GTEST_CXX_FLAGS}) + target_compile_definitions(gmock PRIVATE GTEST_HAS_SEH=0) +endif() + +if(TARGET gmock_main) + target_compile_options(gmock_main PRIVATE ${GTEST_CXX_FLAGS}) + target_compile_definitions(gmock_main PRIVATE GTEST_HAS_SEH=0) +endif() diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 2b2e6e2949..80429a781b 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -12,6 +12,7 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h) find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) +find_package(hiprtc REQUIRED) rocm_setup_version(VERSION 1.0) @@ -27,7 +28,7 @@ add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) -target_link_libraries(ck_host PRIVATE ck_headers) +target_link_libraries(ck_host PRIVATE ck_headers hiprtc::hiprtc) set_target_properties(ck_host PROPERTIES LINKER_LANGUAGE CXX diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 61f3ba5351..03bde86421 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -44,8 +44,7 @@ list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllv example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS}) example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS}) - -list(APPEND gpu_list gfx942 gfx950) +list(APPEND gpu_list gfx942 gfx950 gfx1200 gfx1201 gfx12-generic) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) @@ -89,7 +88,14 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16) + set(target 1) + endif() +endforeach() +list(APPEND gpu_list gfx90a gfx942 gfx950 gfx1200 gfx1201 gfx12-generic) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_gemm_xdl_bf16_streamk_v3 gemm_xdl_bf16_streamk_v3.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_streamk_v3) @@ -99,6 +105,16 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() +list(APPEND gpu_list_tf32 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) + add_example_executable(example_gemm_xdl_lds_direct_load_fp32_tf32 gemm_xdl_lds_direct_load_fp32_tf32.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32_tf32) + set(target 1) + endif() +endforeach() + add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 434f549443..e482953e46 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -310,10 +310,14 @@ bool parse_cmd_args(int argc, return true; } -template +template inline __host__ __device__ constexpr double get_rtol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -351,10 +355,14 @@ inline __host__ __device__ constexpr double get_rtol() } } -template +template inline __host__ __device__ constexpr double get_atol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 6cfff30dbd..0440e8246b 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp index 7178ad46b9..9b1d756f85 100644 --- a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -199,9 +199,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) return true; } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx11_supported() || ck::is_gfx12_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950, gfx11 and gfx12 only" << std::endl; return true; } diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 414683ffdf..66a0d98238 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -37,7 +37,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 8, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 4, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>; // clang-format on using DeviceGemmInstance = DeviceGemmInstance1; diff --git a/example/01_gemm/gemm_xdl_fp16_fp8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp index a996d034e6..2b04d14a11 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -30,7 +30,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| | // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeType>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 2, 16, 16, 0, + 1, 1, S<1, 16, 1, 16>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, CDataType, CDataType, PermuteA, PermuteB>; // clang-format on @@ -281,9 +281,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) return true; } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx11_supported() || ck::is_gfx12_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950, gfx11 and gfx12 only" << std::endl; return true; } diff --git a/example/01_gemm/gemm_xdl_fp16_v2.cpp b/example/01_gemm/gemm_xdl_fp16_v2.cpp index ecd3b7be5d..59c059d014 100644 --- a/example/01_gemm/gemm_xdl_fp16_v2.cpp +++ b/example/01_gemm/gemm_xdl_fp16_v2.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -33,13 +33,13 @@ using DeviceGemmInstance = 2, 256, 256, 256, 32, 8, 4, - 32, 32, - 4, 4, + 16, 16, + 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v1>; // clang-format on diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index 0c51a58037..af9b7978f5 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -31,7 +31,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // this instance has been tested working on gfx950 // < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on @@ -55,4 +55,12 @@ using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm 32, fp8 -> 16, fp16 -> 8 // clang-format off #if 0 using DeviceGemmV2Instance = @@ -56,14 +56,14 @@ using DeviceGemmV2Instance = AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 256, - 128, 16, 32, - 32, 32, - 4, 4, + 128, 16, KPack, + 16, 16, + 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F8, F8, PermuteA, PermuteB>; #endif @@ -160,7 +160,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto gemm = DeviceGemmV2Instance{}; // weight pre-shuffle - int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8 int NLane = gemm.GetPreShuffleParameters(); int KLane = 64 / NLane; @@ -269,9 +268,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) return true; } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx12_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950 and gfx12 only" << std::endl; return true; } diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp index 0575314dff..0e6503d21f 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -38,14 +38,14 @@ using DeviceGemmV2Instance = AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, - KPerBlock, 16, 32, - 32, 32, - 2, 2, + KPerBlock, 16, 16, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, - 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 2, 16, 16, 0, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2, ADataType, ADataType, PermuteA, PermuteB>; // clang-format on @@ -247,9 +247,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) return true; } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx12_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950 and gfx12 only" << std::endl; return true; } diff --git a/example/01_gemm/gemm_xdl_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp8_v3.cpp index da891267b2..a9e39256ba 100644 --- a/example/01_gemm/gemm_xdl_fp8_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_v3.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -36,7 +36,7 @@ using DeviceGemmV2Instance = 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 2, S<1, 32, 1, 8>, 8, + 1, 2, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ck::f8_t>; // clang-format on diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index 3237f1a61c..57702e2178 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp new file mode 100644 index 0000000000..9b92fad779 --- /dev/null +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "common.hpp" + +#define USING_DIRECT_LOADS 1 +#if USING_DIRECT_LOADS +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#endif + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +using F32 = float; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F32; +using ComputeDataType = ck::tf32_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +#if USING_DIRECT_LOADS +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| +// ######| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| +// ######| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler | pipeline ver | gemm type | +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| +// ######| XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, + 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, + 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; +// clang-format on +#else +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#endif +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index d149fd88f1..d5c42558c4 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -36,7 +36,7 @@ using BDataType = ck::half_t; using CDataType = ck::half_t; using AccDataType = float; #else - < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>; + < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>; using ADataType = float; using BDataType = float; using CDataType = float; @@ -185,7 +185,6 @@ int main(int argc, char* argv[]) auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; - // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); @@ -209,8 +208,7 @@ int main(int argc, char* argv[]) return 0; } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index d8672f6a0c..76a30657f0 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -29,7 +29,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_WaveletM // ######| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, F16, CDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1,8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, F16, CDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1,8>, 4>; // clang-format on using DeviceGemmInstance = DeviceGemmInstance; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 3e018aad1e..7fb0c1e812 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -2,7 +2,11 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/library/utility/validation_common.hpp" + +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) @@ -24,11 +28,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if constexpr(std::is_same_v) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; @@ -54,17 +58,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); - try - { - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - } - catch(const std::runtime_error& e) - { - std::cerr << "Error: " << e.what() << std::endl; - return false; - } - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); @@ -218,8 +211,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); #endif } @@ -249,8 +242,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_device_ref_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return pass == true; diff --git a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp index abf7ef3905..1049b5d07c 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_xdl_fp16.cpp @@ -87,10 +87,10 @@ using DeviceOpInstance = 32, 8, 8, - 32, - 32, + 16, + 16, + 8, 4, - 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -108,7 +108,7 @@ using DeviceOpInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; int main(int argc, char* argv[]) { diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp index dffeff2337..992e7c19c8 100644 --- a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp +++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -83,10 +83,10 @@ using DeviceOpInstance = 32, 8, 8, - 32, - 32, + 16, + 16, + 8, 4, - 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -104,7 +104,7 @@ using DeviceOpInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; int main(int argc, char* argv[]) { @@ -113,13 +113,13 @@ int main(int argc, char* argv[]) bool time_kernel = false; // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t M = 1920; + ck::index_t N = 2048; + ck::index_t K = 2048; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideE = 4096; + ck::index_t StrideA = 2048; + ck::index_t StrideB = 2048; + ck::index_t StrideE = 2048; if(argc == 1) { @@ -174,6 +174,9 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + const auto StrideD = std::is_same::value + ? d_m_n.mDesc.GetStrides()[0] + : d_m_n.mDesc.GetStrides()[1]; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; @@ -221,7 +224,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{0}, + std::array{static_cast(StrideD)}, StrideE, a_element_op, b_element_op, diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp index e630f67837..4e98bf3034 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,7 +32,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm int { + if constexpr(std::is_same_v) + { + return static_cast(tensor.GetStrides()[0]); + } + else + { + return static_cast(tensor.GetStrides()[1]); + } + }; + + if(StrideA <= 0) + StrideA = fetch_leading_stride(a_m_k, ALayout{}); + if(StrideB <= 0) + StrideB = fetch_leading_stride(b_k_n, BLayout{}); + if(StrideD0 <= 0) + StrideD0 = fetch_leading_stride(d0_m_n, D0Layout{}); + if(StrideD1 <= 0) + StrideD1 = fetch_leading_stride(d1_m_n, D1Layout{}); + if(StrideE <= 0) + StrideE = fetch_leading_stride(e_m_n_host_result, ELayout{}); + switch(config.init_method) { case 0: break; diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 91c072aef7..4f174bfcbb 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -19,4 +19,13 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) set(target 1) endif() -endforeach() \ No newline at end of file +endforeach() + +list(APPEND gpu_list_tf32 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) + add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index b0fd6a382a..d82b56ec00 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -27,10 +27,14 @@ void print_helper_msg() << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } -template +template inline __host__ __device__ constexpr double get_rtol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 5e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -68,10 +72,14 @@ inline __host__ __device__ constexpr double get_rtol() } } -template +template inline __host__ __device__ constexpr double get_atol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-2; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -116,7 +124,8 @@ template + typename DeviceConvNDFwdInstance, + typename ComputeDataType = OutDataType> bool run_grouped_conv_fwd(bool do_verification, int init_method, bool time_kernel, @@ -228,7 +237,11 @@ bool run_grouped_conv_fwd(bool do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>(); + OutElementOp, + 0, + 0, + 0, + ComputeDataType>(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(in, @@ -249,8 +262,8 @@ bool run_grouped_conv_fwd(bool do_verification, return ck::utils::check_err(out_device, out_host, "Error: incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return true; diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp index b6bb03e1e5..6b66ebbdec 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" @@ -51,10 +51,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -72,7 +72,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; #include "run_convnd_fwd_example.inc" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp index 0fc9e7b5dd..d270d446b5 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" @@ -52,10 +52,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -73,9 +73,17 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, ComputeType>; #include "run_convnd_fwd_example.inc" -int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp index 9eba00993a..21bfd71a69 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" @@ -53,10 +53,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -74,10 +74,18 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeType, BComputeType>; #include "run_convnd_fwd_example.inc" -int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index 064a971478..7db7fdf4a8 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" @@ -51,10 +51,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -72,7 +72,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; #include "run_convnd_fwd_example.inc" diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp index 346ab8d953..62040384ad 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" @@ -52,10 +52,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -73,9 +73,17 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, ComputeType>; #include "run_convnd_fwd_example.inc" -int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } +int main(int argc, char* argv[]) +{ + // fp8 are not supported on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index 36517e569d..40c38b39d8 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" @@ -76,4 +76,11 @@ using DeviceGroupedConvNDFwdInstance = #include "run_convnd_fwd_example.inc" -int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } +int main(int argc, char* argv[]) +{ + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp new file mode 100644 index 0000000000..348da7e1ef --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +using InDataType = float; +using WeiDataType = float; +using AccDataType = float; +using CShuffleDataType = float; +using OutDataType = float; +using ComputeDataType = ck::tf32_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, // ALayout + WeiLayout, // BLayout + ck::Tuple<>, // DsLayout + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + ck::Tuple<>, // DsDataType + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 192, // NPerBlock + 16, // KPerBlock + 4, // AK1 + 4, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 3, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4, // CDEBlockTransferScalarPerVector_NPerBlock + ComputeDataType, // AComputeDataType + ComputeDataType, // BComputeDataType + ck::LoopScheduler::Default, // LoopScheduler + 1 // NumGroupsToMerge + >; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp index ef130148bc..c635d01d8f 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" @@ -7,6 +7,8 @@ #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#define EXAMPLE_WITH_COMPUTE_DATATYPE + using InDataType = ck::f8_t; using WeiDataType = ck::f8_t; using AccDataType = float; @@ -52,10 +54,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -73,9 +75,19 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, ComputeDataType>; #include "run_convnd_fwd_example.inc" -int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp index 53a12377c5..de6350db88 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" @@ -53,10 +53,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -74,10 +74,18 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeType, BComputeType>; #include "run_convnd_fwd_example.inc" -int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return run_convnd_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index 0180e6e718..4ed47d2cae 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_common.hpp" @@ -51,10 +51,10 @@ using DeviceGroupedConvNDFwdInstance = 64, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -71,8 +71,8 @@ using DeviceGroupedConvNDFwdInstance = 1, // BBlockLdsExtraN 1, 1, - S<1, 64, 1, 4>, - 16>; + S<1, 32, 1, 8>, + 4>; #include "run_convnd_fwd_example.inc" diff --git a/example/09_convnd_fwd/run_convnd_fwd_example.inc b/example/09_convnd_fwd/run_convnd_fwd_example.inc index 49852ff667..016a189d4b 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_example.inc @@ -3,6 +3,11 @@ #pragma once +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + bool run_convnd_fwd_example(int argc, char* argv[]) { print_helper_msg(); @@ -65,17 +70,17 @@ bool run_convnd_fwd_example(int argc, char* argv[]) InElementOp, WeiElementOp, OutElementOp, - DeviceGroupedConvNDFwdInstance>( - do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); + DeviceGroupedConvNDFwdInstance, + ComputeDataType>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); }; namespace ctc = ck::tensor_layout::convolution; diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 036f288d0a..7142521c55 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -125,7 +125,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); problem_size = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp index 5848785673..c1ee36ef99 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/convnd_fwd_max_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -15,4 +15,11 @@ using RsDataType = ck::Tuple; #include "run_convnd_fwd_max_example.inc" -int main(int argc, char* argv[]) { return !run_convnd_fwd_max_example(argc, argv); } +int main(int argc, char* argv[]) +{ + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + return !run_convnd_fwd_max_example(argc, argv); +} diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index d61aee81a4..4b290d02a2 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -23,7 +23,7 @@ using RsGlobalReduceOp = static constexpr auto ConvSpec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off template @@ -36,7 +36,7 @@ using DeviceInstance = #ifdef BUILD_INT4_EXAMPLE < NDimSpatial, ALayout, BLayout, DELayout, RLayout, KernelADataType, KernelBDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; #else - < NDimSpatial, ALayout, BLayout, DELayout, RLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; + < NDimSpatial, ALayout, BLayout, DELayout, RLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, ConvSpec, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>; #endif template diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp index eb8b5c76d3..9e125c4e5d 100644 --- a/example/12_reduce/reduce_blockwise_two_call.cpp +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -100,13 +100,13 @@ int main(int argc, char* argv[]) const std::array reduceDims = {3, 4}; // const std::array invariantDims = {0, 1, 2}; - const std::vector inLengths_1 = {64, 320, 80, 4, 128}; + std::vector inLengths_1 = {64, 320, 80, 4, 128}; // input lengths of the second reduction, which is also the output lengths of the first // reduction - const std::vector inLengths_2 = {64, 320, 80, 4}; + std::vector inLengths_2 = {64, 320, 80, 4}; - const std::vector outLengths = {64, 320, 80}; + std::vector outLengths = {64, 320, 80}; if(argc == 1) { @@ -114,11 +114,26 @@ int main(int argc, char* argv[]) init_method = 2; time_kernel = true; } - else if(argc == 4) + else if((argc == 4) || (argc == 9)) { do_verify = static_cast(argv[1]); init_method = atoi(argv[2]); time_kernel = static_cast(atoi(argv[3])); + if(argc == 9) + { + inLengths_1[0] = atoi(argv[4]); + inLengths_1[1] = atoi(argv[5]); + inLengths_1[2] = atoi(argv[6]); + inLengths_1[3] = atoi(argv[7]); + inLengths_1[4] = atoi(argv[8]); + inLengths_2[0] = inLengths_1[0]; + inLengths_2[1] = inLengths_1[1]; + inLengths_2[2] = inLengths_1[2]; + inLengths_2[3] = inLengths_1[3]; + outLengths[0] = inLengths_1[0]; + outLengths[1] = inLengths_1[1]; + outLengths[2] = inLengths_1[2]; + } } else { diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp index 3ce08fd2af..abbf1b29f7 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_common.hpp +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -78,12 +78,12 @@ bool pool_test(bool do_verification, if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 8703fa3ed7..b058e7b0fa 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,3 +1,4 @@ add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) +add_example_executable(example_gemm_wmma_quantization_int8 gemm_wmma_quantization_int8.cpp) add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) diff --git a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp index 2585072dfe..5291f5ce69 100644 --- a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp @@ -115,12 +115,14 @@ int main() if(std::is_same::value) { return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1_uz})); + std::vector({stride, 1_uz}), + layout); } else { return HostTensorDescriptor(std::vector({row, col}), - std::vector({1_uz, stride})); + std::vector({1_uz, stride}), + layout); } }; diff --git a/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp new file mode 100644 index 0000000000..a3023997a1 --- /dev/null +++ b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DsDataType = ck::Tuple<>; +using EDataType = I8; + +using ALayout = Col; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + ActivationOp, + ActivationOp, + CDEElementOp, + GemmDefault, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 64, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + true, + S<4, 64, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + true, + 1, + 1, + S<1, 32, 1, 8>, + S<1>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + I8, + I8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int /* argc */, char* /* argv */[]) +{ + bool do_verification = true; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideE = N; + + float requant_scale = 0.03; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + // device GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + std::array{}, + static_cast(e_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp index aa3e011695..8f68ac6b05 100644 --- a/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_xdl_bias_relu_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -70,10 +70,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl 64, // KPerBlock, 16, // AK1, 16, // BK1, - 32, // MPerXDL, - 32, // NPerXDL, - 4, // MXdlPerWave, - 2, // NXdlPerWave, + 16, // MPerXDL, + 16, // NPerXDL, + 8, // MXdlPerWave, + 4, // NXdlPerWave, S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1, S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, S<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -90,8 +90,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl 1, // bool BBlockLdsExtraN, 1, // index_t CShuffleMXdlPerWavePerShuffle, 1, // index_t CShuffleNXdlPerWavePerShuffle, - S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm #include @@ -68,10 +68,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl 64, // KPerBlock, 16, // AK1, 16, // BK1, - 32, // MPerXDL, - 32, // NPerXDL, - 4, // MXdlPerWave, - 2, // NXdlPerWave, + 16, // MPerXDL, + 16, // NPerXDL, + 8, // MXdlPerWave, + 4, // NXdlPerWave, S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1, S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, S<1, 0, 2>, // ABlockTransferSrcAccessOrder, @@ -88,8 +88,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl 1, // bool BBlockLdsExtraN, 1, // index_t CShuffleMXdlPerWavePerShuffle, 1, // index_t CShuffleNXdlPerWavePerShuffle, - S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp index 63a2aea0b3..c8de51f550 100644 --- a/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -63,7 +63,7 @@ using DeviceGemmInstance = //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<4,4,4>>; // clang-format on struct ProblemSize final diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp index 680cee1f81..ac64a468a4 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp index 5bdc993192..2fcc0e3cb1 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -323,6 +323,31 @@ int main(int argc, char* argv[]) problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + printf("arg5: group count (default=16)"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ns.push_back(768); @@ -333,21 +358,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: k_batch (>0)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 6806bd1886..fb611fd444 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -296,6 +296,32 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + printf("arg5: group count (default=16)"); + + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(128 + rand() % 128); @@ -307,21 +333,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: k_batch (> 0)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index 8418c10f5e..47eb6637bd 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -297,6 +297,31 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + printf("arg5: group count (default=16)"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); @@ -308,21 +333,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: k_batch (> 0)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 90a12bc1dd..85ea8c2f2c 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index 28b0fcd0ce..fb047ae364 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp index 0c96ef56d3..6b8de05f73 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -51,7 +51,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp index 9f8f6cb1e4..16d018936b 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -66,6 +66,28 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.group_count = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: group count (default=16)"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); @@ -77,19 +99,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 4) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 7186c22233..4ef6074f4a 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -278,6 +278,30 @@ bool run_grouped_gemm_example(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.async_hargs = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: async hargs (0=n0, 1=yes)\n"); + printf("arg5: group count (default=16)"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); @@ -288,27 +312,6 @@ bool run_grouped_gemm_example(int argc, char* argv[]) problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 4) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - } - else if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.async_hargs = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: async hargs (0=n0, 1=yes)\n"); - exit(0); - } return run_grouped_gemm(problem_size, config); } diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index a46eaa4816..3cc38b381b 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -76,7 +76,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip //######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector| //######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | - < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; + < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder S<1, 0, 2>, // ABlockTransfer SrcAccessOrder @@ -92,7 +92,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock 1>; // RThread DstScalarPerVector _MPerBlock // clang-format on diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp index b30ce2c48a..0290c1829d 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" @@ -65,10 +65,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder S<1, 0, 2>, // ABlockTransfer SrcAccessOrder @@ -85,7 +85,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock 1>; // RThread DstScalarPerVector _MPerBlock // clang-format on diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp index 31e2efd6f6..e211a63b0b 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" @@ -65,10 +65,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder S<1, 0, 2>, // ABlockTransfer SrcAccessOrder @@ -85,7 +85,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock 1>; // RThread DstScalarPerVector _MPerBlock // clang-format on diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp index d3c7c1d99c..90c2cdcdaa 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_max_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" @@ -146,6 +146,11 @@ int main(int argc, char* argv[]) exit(0); } + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + return run_gemm_reduce_max_xdl, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder S<1, 0, 2>, // ABlockTransfer SrcAccessOrder @@ -84,7 +84,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock 1>; // RThread DstScalarPerVector _MPerBlock // clang-format on diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp index 5c2706c79a..3ee3037179 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" @@ -72,10 +72,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder S<1, 0, 2>, // ABlockTransfer SrcAccessOrder @@ -92,7 +92,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock 1>; // RThread DstScalarPerVector _MPerBlock // clang-format on diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp index c119e24370..9ce1e76cf5 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" @@ -72,10 +72,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransfer ThreadCluster Lengths_K0_M_K1 S<1, 0, 2>, // ABlockTransfer ThreadCluster ArrangeOrder S<1, 0, 2>, // ABlockTransfer SrcAccessOrder @@ -92,7 +92,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip 1, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle - S<64, 4>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock + S<32, 8>, // CD Reduce Thread Transfer ClusterLengths _MPerBlock_NPerBlock 4, // CDE ReduceThreadTransfer ScalarPerVector _NPerBlock 1>; // RThread DstScalarPerVector _MPerBlock // clang-format on diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp index 0f5e588383..7815d2beea 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_mean_meansquare_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gemm_reduce_xdl_common.hpp" @@ -153,6 +153,11 @@ int main(int argc, char* argv[]) exit(EXIT_SUCCESS); } + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + exit(EXIT_SUCCESS); + } + return !run_gemm_reduce_mean_meansquare_xdl #include @@ -64,7 +64,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; + < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, S<32, 8>, 4, 1>; // clang-format on using ReferenceBatchedGemmInstance = @@ -137,11 +137,13 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {row * stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {col * stride, 1_uz, stride}, layout); } }; diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index e0034bf7eb..9159e51eaf 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -123,7 +123,9 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_param = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, + threshold_to_catch_partial_args + 1, // +1 because we already parsed num_dim_spatial + argv); } else { diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index 80b9724930..b19aff2081 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -44,10 +44,10 @@ using DeviceConvBwdWeightInstance = 128, // NPerBlock 4, // K0PerBlock 8, // K1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 4, // NXdlPerWave S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder @@ -80,6 +80,11 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe int main(int argc, char* argv[]) { + if(ck::is_gfx11_supported()) + { + return 0; + } + ExecutionConfig config; ck::utils::conv::ConvParam conv_param = DefaultConvParam; diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp index 5dccb11bba..abbc7a946c 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -80,7 +80,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip //######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector| //######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | - < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; + < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm(std::stoi(argv[2])); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); @@ -357,6 +376,7 @@ int main() normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, false}); bool pass = true; + if(do_verification) { // verification Tensor host_layerNorm_m_n( @@ -383,27 +403,25 @@ int main() 1e-2); } + if(time_kernel) { // evaluate kernel perf - bool time_kernel = true; - float gemm_reduce_mean_reduce_square_mean_ave_time = - gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, true}); float normalize_ave_time = - normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, true}); - if(time_kernel) - DumpGemmLayerNormPerf( - gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); + DumpGemmLayerNormPerf( + gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); } return pass ? 0 : 1; diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index 6a92e9a2f5..ae5e3f36ad 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -65,7 +65,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern //######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| //######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>; + < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, S<8, 32>, 4>; // clang-format on auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { @@ -154,6 +154,12 @@ void host_gemm_layernorm(Tensor& h_m_n, int main() { + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + bool do_verification = true; // GEMM shape diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp index 168193ad5b..23c602c39e 100644 --- a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -77,7 +77,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDMultip //######| | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| DstScalarPerVector| //######| | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| _MPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| | - < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, 1>; + < ALayout, BLayout, ELayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, EDataType, ReduceAccDataType, RsDataType, AElementOp, BElementOp, CDEElementOp, QsElementOp, RsElementOp, RsThreadReduceOp, RsGlobalReduceOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm(std::stoi(argv[2])); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); @@ -333,6 +352,7 @@ int main() normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, false}); bool pass = true; + if(do_verification) { // verification Tensor host_layerNorm_m_n( @@ -354,25 +374,23 @@ int main() layerNorm_m_n, host_layerNorm_m_n, "Error: Incorrect results d1", 1e-3, 1e-3); } + if(time_kernel) { // evaluate kernel perf - bool time_kernel = true; - float gemm_reduce_mean_reduce_square_mean_ave_time = - gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, true}); float normalize_ave_time = - normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + normalize_invoker.Run(normalize_argument_ptr.get(), StreamConfig{nullptr, true}); - if(time_kernel) - DumpGemmLayerNormPerf( - gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); + DumpGemmLayerNormPerf( + gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); } return pass ? 0 : 1; diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp index 277fea0272..10d90b795c 100644 --- a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -70,7 +70,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmLayerNorm_Xdl //######| | | | Type| Type| Type| Type| DataType| DataType| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, S<64, 4>, 4>; + < Row, Col, Row, ADataType, BDataType, CDataType, C0DataType, AccDataType, CShuffleDataType, AccDataType, AElementOp, BElementOp, AccElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 4, S<1, 32, 1, 8>, 8, S<32, 8>, 4>; // clang-format on using ReferenceInstance = ck::tensor_operation::host::ReferenceGemmLayernorm @@ -48,10 +48,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 32, // index_t KPerBlock 8, // index_t AK1 8, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -69,7 +69,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp index 89a581e865..2996d87b28 100644 --- a/example/22_cgemm/cgemm_xdl_fp16.cpp +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -47,10 +47,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 32, // index_t KPerBlock 8, // index_t AK1 8, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -68,7 +68,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) diff --git a/example/22_cgemm/cgemm_xdl_fp32.cpp b/example/22_cgemm/cgemm_xdl_fp32.cpp index cf96599599..45d23b4377 100644 --- a/example/22_cgemm/cgemm_xdl_fp32.cpp +++ b/example/22_cgemm/cgemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -48,10 +48,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 16, // index_t KPerBlock 4, // index_t AK1 4, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -69,11 +69,16 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle S<1, 16, 1, 16>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + 2>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + bool do_verification = true; int init_method = 1; bool time_kernel = false; @@ -87,25 +92,25 @@ int main(int argc, char* argv[]) ck::index_t StrideB = 4096; ck::index_t StrideC = 4096; - if(argc == 4) + if(argc == 1) { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); + // use default case } - else if(argc == 10) + else if(argc == 4 || argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); + if(argc == 10) + { + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } } else { @@ -114,7 +119,7 @@ int main(int argc, char* argv[]) << "arg3: run kernel # of times (>1)\n" << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" << std::endl; - exit(0); + exit(1); } return !run_cgemm_xdl @@ -48,10 +48,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 64, // index_t KPerBlock 16, // index_t AK1 16, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -68,8 +68,8 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t BBlockLdsExtraN 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle - S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) @@ -87,25 +87,25 @@ int main(int argc, char* argv[]) ck::index_t StrideB = 4096; ck::index_t StrideC = 4096; - if(argc == 4) + if(argc == 1) { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); + // use default case } - else if(argc == 10) + else if(argc == 4 || argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); + if(argc == 10) + { + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } } else { @@ -114,7 +114,7 @@ int main(int argc, char* argv[]) << "arg3: run kernel # of times (>1)\n" << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" << std::endl; - exit(0); + exit(1); } return !run_cgemm_xdl #include #include @@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp index 548500518f..bd4be91103 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp @@ -68,10 +68,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -89,11 +89,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - S<8>, // CDEShuffleBlockTransferScalarPerVectors + S<4>, // CDEShuffleBlockTransferScalarPerVectors ck::BlockGemmPipelineScheduler::Intrawave, // BlockGemmPipelineScheduler ck::BlockGemmPipelineVersion::v3 // BlockGemmPipelineVersion >; #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp index d1985f9af5..80f6b1d663 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp index 42171bcdb7..46e452005d 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp index a92a04dbe6..19a63ff841 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -50,9 +52,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + + return run_batched_gemm_example(argc, argv); +} diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp index 84f92eba8e..2a4610bf90 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp @@ -74,10 +74,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 64, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -95,7 +95,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - S<8, 8, 1>, // CDEShuffleBlockTransferScalarPerVectors + S<4, 4, 1>, // CDEShuffleBlockTransferScalarPerVectors ck::BlockGemmPipelineScheduler::Interwave, // BlockGemmPipelineScheduler ck::BlockGemmPipelineVersion::v1, // BlockGemmPipelineVersion F8 // ComputeTypeA @@ -103,4 +103,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD #include "run_batched_gemm_example_rowwise.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_rowwise_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_rowwise_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_int4.cpp b/example/24_batched_gemm/batched_gemm_xdl_int4.cpp index 5e82cfe324..4adb79ed29 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_int4.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_int4.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -96,4 +98,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD #define BUILD_INT4_EXAMPLE #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_int8.cpp b/example/24_batched_gemm/batched_gemm_xdl_int8.cpp index ad22227af5..06b8c3c0cd 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_int8.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_int8.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -48,9 +50,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc index 741512bf00..182ab8d967 100644 --- a/example/24_batched_gemm/run_batched_gemm_example.inc +++ b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #pragma once @@ -59,11 +61,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(std::is_same::value) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; @@ -214,35 +218,37 @@ bool run_batched_gemm_example(int argc, char* argv[]) problem_size.batch_count = 2; - if(argc == 4) + if(argc == 1) + { + // use default case + } + else if(argc == 4 || argc == 8) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); - } - else if(argc == 8) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - problem_size.M = std::stoi(argv[4]); - problem_size.N = std::stoi(argv[5]); - problem_size.K = std::stoi(argv[6]); - problem_size.batch_count = std::stoi(argv[7]); + if(argc == 8) + { + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + problem_size.batch_count = std::stoi(argv[7]); + } } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("optinal\n"); - printf("arg4-7: M = %d N = %d K = %d Batch = %d\n", - problem_size.M, - problem_size.N, - problem_size.K, - problem_size.batch_count); - exit(0); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("optional\n"); + printf("arg4-7: M, N, K, Batch\n"); + exit(1); } + printf("M = %d N = %d K = %d Batch = %d\n", + problem_size.M, + problem_size.N, + problem_size.K, + problem_size.batch_count); problem_size.stride_A = problem_size.K; problem_size.stride_B = problem_size.K; diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 3582bc5e33..5e56670fcf 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -137,11 +137,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto layout) { if constexpr(std::is_same_v) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; @@ -344,7 +346,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; - return true; + return false; } bool pass = true; @@ -521,6 +523,11 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 1; + } + ProblemSize problem_size; ExecutionConfig config; @@ -533,30 +540,30 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) problem_size.batch_count = 2; - if(argc == 4) + if(argc == 1) { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); + // use default case } - else if(argc >= 7) + else if(argc == 4 || argc >= 7) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); - - problem_size.M = std::stoi(argv[4]); - problem_size.N = std::stoi(argv[5]); - problem_size.K = std::stoi(argv[6]); - - if(argc >= 8) + if(argc >= 7) { - problem_size.batch_count = std::stoi(argv[7]); - } + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); - if(argc >= 9) - { - problem_size.KBatch = std::stoi(argv[8]); + if(argc >= 8) + { + problem_size.batch_count = std::stoi(argv[7]); + } + + if(argc >= 9) + { + problem_size.KBatch = std::stoi(argv[8]); + } } } else @@ -564,7 +571,10 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); - exit(0); + printf("arg4-6: problem size (M, N, K)\n"); + printf("arg7: batch count\n"); + printf("arg8: KBatch\n"); + exit(1); } problem_size.stride_A = problem_size.K; diff --git a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc index 778be8ffd7..6ed0b23407 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #pragma once @@ -64,11 +64,13 @@ bool run_batched_gemm_rowwise(const ProblemSize& problem_size, const ExecutionCo if(std::is_same::value) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp index 420a7cf74f..4f4003809b 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -19,6 +19,9 @@ #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -342,7 +345,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_gs_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1 using S = ck::Sequence; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -342,7 +345,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_gs_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1 a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; @@ -189,7 +191,7 @@ int run_contraction_bilinear_example(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; @@ -173,7 +175,7 @@ int run_contraction_scale_example(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 #include @@ -18,6 +18,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -53,7 +56,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; + DeviceGroupedContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on // hardcoded for NumDimM == NumDimN == NumDimK == 2 @@ -194,22 +197,28 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = false; - if(argc == 4) + std::size_t group_count = rand() % 16 + 1; + + if(argc == 1) + { + // use default + } + else if(argc == 5) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); + group_count = std::stoi(argv[4]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: group count (default = random from 1..16)"); exit(0); } - std::size_t group_count = rand() % 16 + 1; - // GEMM shape std::vector> contraction_descs; std::vector p_a, p_b; @@ -298,10 +307,10 @@ int main(int argc, char* argv[]) const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Bypass{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); ck::index_t M_ = ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); @@ -410,9 +419,9 @@ int main(int argc, char* argv[]) const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data()); diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index f556be887f..c4cb7a13a2 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -17,6 +17,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -300,11 +303,11 @@ int main(int argc, char* argv[]) std::vector e_gs_ms_ns_strides{ G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; @@ -396,7 +399,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 #include @@ -17,6 +17,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -54,7 +57,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; + DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -345,7 +348,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -128,7 +128,7 @@ using DeviceConvFwdInstance = 1, // BBlockLdsExtraN 1, 1, - S<1, 16, 1, 16>, + S<1, 32, 1, 8>, 4>; template diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc index da65bb1886..c661871dfa 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_example.inc @@ -27,10 +27,10 @@ using DeviceConvFwdInstance = 16, // KPerBlock 4, // AK1 4, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -47,7 +47,7 @@ using DeviceConvFwdInstance = 1, // BBlockLdsExtraN 1, 1, - S<1, 16, 1, 16>, + S<1, 32, 1, 8>, 4>; template diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 11933f09a9..811b133b44 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -11,6 +11,6 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) -if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1") +if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95") add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) endif() diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_base.inc b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_base.inc index c35b6be53e..de4f5f09e7 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_base.inc +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_wmma_cshuffle_v3_base.inc @@ -261,6 +261,10 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm< B1ElementOp, CElementOp>; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + #include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc" int main(int argc, char* argv[]) diff --git a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp index 7605d9c4f8..9afd199f24 100644 --- a/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp +++ b/example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o @@ -84,11 +84,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 4, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 8, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -113,7 +113,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -113,7 +113,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -110,7 +110,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_X 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; @@ -270,7 +272,18 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[]) c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); #endif - return ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result); + if constexpr(ck::is_same_v) + { + return ck::utils::check_err(c_g_m_o_device_result, + c_g_m_o_host_result, + "Error: Incorrect results!", + 1e-3, + 1.1e-3); + } + else + { + return ck::utils::check_err(c_g_m_o_device_result, c_g_m_o_host_result); + } } return true; diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc index 8ab47c2925..cea18459f4 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc @@ -62,17 +62,19 @@ int run(int argc, char* argv[]) std::vector b1_g_o_n_lengths{G, O, N}; #ifdef CK_MHA_USE_RCCR_LAYOUT std::vector b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N] + auto b1_layout = Row{}; #else std::vector b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O] + auto b1_layout = Col{}; #endif std::vector c_g_m_o_lengths{G, M, O}; std::vector c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O] - Tensor a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides); - Tensor b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides); - Tensor b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides); - Tensor c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides); - Tensor c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides); + Tensor a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides, Row{}); + Tensor b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides, Row{}); + Tensor b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides, b1_layout); + Tensor c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides, Row{}); + Tensor c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides, Row{}); std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index 1d1566d575..2604a50a76 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -101,11 +101,11 @@ using DeviceGemmInstance = 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -130,7 +130,7 @@ using DeviceGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec>; // MaskingSpecialization // Ref Gemm0: fp16 in, fp32 out diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp index bae88d4b8e..331bfe99c2 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -100,11 +100,11 @@ using DeviceGemmInstance = 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -129,7 +129,7 @@ using DeviceGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec>; // MaskingSpecialization // Ref Gemm0: bf16 in, fp32 out diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp index a098ce6675..cd321c0da3 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -101,11 +101,11 @@ using DeviceGemmInstance = 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -130,7 +130,7 @@ using DeviceGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec>; // MaskingSpecialization // Ref Gemm0: fp16 in, fp32 out diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp index ce8caf7588..f30ec3fd03 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -84,11 +84,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftma 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -113,7 +113,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftma 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock false>; // Ref Gemm0: fp16 in, fp32 out diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp index 138db14963..e403ba7f66 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -85,11 +85,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftma 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -114,7 +114,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftma 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock false>; // Ref Gemm0: fp16 in, fp32 out diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index 5794924294..7738a6b6d4 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -100,11 +100,11 @@ using DeviceGemmInstance = 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -129,7 +129,7 @@ using DeviceGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec>; // MaskingSpecialization // Ref Gemm0: fp16 in, fp32 out diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp index 97caec6053..b59498829e 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -101,11 +101,11 @@ using DeviceGemmInstance = 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -130,7 +130,7 @@ using DeviceGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec>; // MaskingSpecialization // Ref Gemm0: fp16 in, fp32 out diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc index 1514fc48b3..aa2a6b3b42 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -111,12 +111,14 @@ int run(int argc, char* argv[]) if(std::is_same::value) { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, stride, 1})); + std::vector({batch_stride, stride, 1}), + layout); } else { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, 1, stride})); + std::vector({batch_stride, 1, stride}), + layout); } }; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc index 2b02069e65..6175f0b5be 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -88,11 +90,11 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Bypass{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Bypass{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Bypass{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc index e0ccb6dad1..db13e3b963 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -88,11 +92,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc index 0ad031cc71..1e4b52d4cf 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -113,11 +117,30 @@ int run(int argc, char* argv[]) head_dim, 1}; // C layout [batch_size, head_num, q_sequence_length, head_dim] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; @@ -191,7 +214,7 @@ int run(int argc, char* argv[]) head_num * 2 * head_dim, head_dim, 1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim] - Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides); + Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides, Bypass{}); // merge kv into a packed pointer send to device b0_gs_ns_ks.ForEach( [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index c693995140..874d987a1d 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -63,6 +67,19 @@ int run(int argc, char* argv[]) std::size_t flop = 0, num_byte = 0; + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; std::cout << "group count " << group_count << ". printing first 4 groups\n"; for(std::size_t i = 0; i < group_count; i++) { @@ -113,10 +130,14 @@ int run(int argc, char* argv[]) {}}); // acc1_biases_gs_ms_os_strides // C_m_o = A_m_k * B0_k_n * B1_n_o - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks(f_host_tensor_descriptor( + b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns(f_host_tensor_descriptor( + b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_device_result(f_host_tensor_descriptor( + c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); int Batch = G0 * G1; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; @@ -252,7 +273,8 @@ int run(int argc, char* argv[]) Tensor acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 Tensor a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax Tensor c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_host_result(f_host_tensor_descriptor( + c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); // permute a_gs_ms_ks.ForEach([&](auto& self, auto idx) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc index 7ac29f33ca..1c2a26d916 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -91,11 +95,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc index fb9b1b0bd7..76f3ee756c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -91,11 +95,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc index 2cb69380e5..86754927ed 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -108,11 +112,30 @@ int run(int argc, char* argv[]) head_dim, 1}; // C layout [batch_size, head_num, sequence_length, head_dim] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; @@ -186,7 +209,7 @@ int run(int argc, char* argv[]) head_num * 3 * head_dim, head_dim, 1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim] - Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides); + Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides, Bypass{}); // merge qkv into a packed pointer send to device a_gs_ms_ks.ForEach( [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 904006ba36..e0476bfaad 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -27,3 +27,16 @@ add_example_executable(example_gemm_xdl_splitk_reduce_multi_d_bf16 gemm_xdl_spli add_example_executable(example_gemm_xdl_splitk_reduce_bf16A_i8B gemm_xdl_splitk_reduce_bf16A_i8B.cpp) add_example_executable(example_gemm_xdl_splitk_reduce_bfp16 gemm_xdl_splitk_reduce_bf16.cpp) + +add_custom_target(example_splitK_gemm_wmma) +add_example_executable(example_gemm_wmma_splitk_reduce_bf16 gemm_wmma_splitk_reduce_bf16.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16) + +add_example_executable(example_gemm_wmma_splitk_reduce_bf16A_i8B gemm_wmma_splitk_reduce_bf16A_i8B.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_bf16A_i8B) + +add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_bf16 gemm_wmma_splitk_reduce_multi_d_bf16.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_bf16) + +add_example_executable(example_gemm_wmma_splitk_reduce_multi_d_fp16 gemm_wmma_splitk_reduce_multi_d_fp16.cpp) +add_example_dependencies(example_splitK_gemm_wmma example_gemm_wmma_splitk_reduce_multi_d_fp16) diff --git a/example/35_splitK_gemm/common.hpp b/example/35_splitK_gemm/common.hpp index 64fadae9e5..325cc37731 100644 --- a/example/35_splitK_gemm/common.hpp +++ b/example/35_splitK_gemm/common.hpp @@ -99,3 +99,85 @@ bool parse_cmd_args(int argc, return true; } + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp new file mode 100644 index 0000000000..b481483d42 --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = ck::bhalf_t; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple<>; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple<>; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceWmmaGemmInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_example.inc" + +int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp new file mode 100644 index 0000000000..dcf4a1652d --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_bf16A_i8B.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = int8_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = float; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple<>; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = Row; +using DsLayout = ck::Tuple<>; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceWmmaGemmInstance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_example.inc" + +int main(int argc, char* argv[]) { return !run_wmma_gemm_splitk_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp new file mode 100644 index 0000000000..dab308d148 --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_bf16.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using ReduceDataType = float; +using D0DataType = ck::bhalf_t; +using DsDataType = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 128, 32, + 8, 8, + 16, 16, + 4, 2, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp new file mode 100644 index 0000000000..489816559d --- /dev/null +++ b/example/35_splitK_gemm/gemm_wmma_splitk_reduce_multi_d_fp16.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; +using ReduceDataType = float; +using D0DataType = ck::half_t; +using DsDataType = ck::Tuple; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; +using D0Layout = CLayout; +using DsLayout = ck::Tuple; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Add; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3R1< + ALayout, BLayout, DsLayout, CLayout, + ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmDefault, + 256, + 128, 256, 64, + 8, 8, + 16, 16, + 4, 4, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 1, 8, true, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, ReduceDataType>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_wmma_splitk_reduce_multi_d_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_multi_d_example(argc, argv); } diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp index 7ceb1d09ef..1843198933 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; // clang-format on diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp index b5aeff65d6..1e4398b9f6 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ReduceDataType>; // clang-format on diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp index cb84f2a416..d5acde139a 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ReduceDataType>; // clang-format on diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp index 2ab8f77dc4..bb3c23f060 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v2, ReduceDataType>; // clang-format on diff --git a/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc b/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc index 9635993d63..0b060841bf 100644 --- a/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc +++ b/example/35_splitK_gemm/run_gemm_splitk_reduce_multi_d_example.inc @@ -3,88 +3,6 @@ #pragma once -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { diff --git a/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc new file mode 100644 index 0000000000..25628ef770 --- /dev/null +++ b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_example.inc @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +bool run_wmma_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << KBatch << std::endl; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device GEMM + auto device_op = DeviceWmmaGemmInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto argument = + device_op.MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + std::array{}, // empty D tensors + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, // empty D strides + StrideC, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + // Allocate workspace for split-K reduction if needed + size_t workspace_size = device_op.GetWorkSpaceSize(argument.get()); + DeviceMem workspace_buf(workspace_size); + std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl; + if(workspace_size > 0) + { + argument->p_workspace_ = workspace_buf.GetDeviceBuffer(); + std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl; + } + + if(!device_op.IsSupportedArgument(argument.get())) + { + std::cout << "The runtime argument is not supported!" << std::endl; + std::cout << "Debug info:" << std::endl; + std::cout << " M=" << M << ", N=" << N << ", K=" << K << ", KBatch=" << KBatch + << std::endl; + std::cout << " StrideA=" << StrideA << ", StrideB=" << StrideB << ", StrideC=" << StrideC + << std::endl; + return false; + } + + bool pass = true; + float ave_time = 0; + + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, false}); + + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E12 / ave_time; + + float gb_per_sec = num_btype / 1.E9 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << device_op.GetTypeString() << std::endl; + } + + return pass; +} + +bool run_wmma_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config); +} diff --git a/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc new file mode 100644 index 0000000000..59996655c6 --- /dev/null +++ b/example/35_splitK_gemm/run_gemm_wmma_splitk_reduce_multi_d_example.inc @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +bool run_wmma_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto StrideD0 = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{}); + + Tensor a_m_k( + f_host_tensor_descriptor(problem_size.M, problem_size.K, problem_size.StrideA, ALayout{})); + Tensor b_k_n( + f_host_tensor_descriptor(problem_size.K, problem_size.N, problem_size.StrideB, BLayout{})); + Tensor d0_m_n( + f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, D0Layout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + Tensor c_m_n_host_result( + f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(problem_size.M, problem_size.N, problem_size.StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << KBatch << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + constexpr auto kNum_DTensors = DsDataType::Size(); + const std::array p_ds = {d0_m_n_device_buf.GetDeviceBuffer()}; + const std::array d_strides = {problem_size.StrideC}; + + auto argument = + gemm.MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + p_ds, + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + problem_size.M, + problem_size.N, + problem_size.K, + problem_size.StrideA, + problem_size.StrideB, + d_strides, + problem_size.StrideC, + problem_size.KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument.get())) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + return false; + } + + auto workspace_size = gemm.GetWorkSpaceSize(argument.get()); + DeviceMem workspace_device_buf(workspace_size); + + std::cout << "Workspace size: " << workspace_size << " bytes" << std::endl; + std::cout << "Allocated workspace of size: " << workspace_size << " bytes" << std::endl; + + if(workspace_size > 0) + { + argument->p_workspace_ = workspace_device_buf.GetDeviceBuffer(); + } + + if(config.do_verification) + { + using ReferenceGemmInstanceMultiD = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstanceMultiD{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + c_m_n_host_result.ForEach( + [&](auto& self, auto idx) { c_element_op(self(idx), self(idx), d0_m_n(idx)); }); + } + + std::cout << "init method: " << config.init_method << std::endl; + std::cout << "KBatch: " << problem_size.KBatch << std::endl; + + float ave_time = invoker.Run(argument.get(), StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = std::size_t(2) * problem_size.M * problem_size.N * problem_size.K; + std::size_t num_btype = sizeof(ADataType) * problem_size.M * problem_size.K + + sizeof(BDataType) * problem_size.K * problem_size.N + + sizeof(CDataType) * problem_size.M * problem_size.N + + sizeof(D0DataType) * problem_size.M * problem_size.N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(config.do_verification) + { + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + double rtol = get_rtol(); + double atol = get_atol(); + + return ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", rtol, atol); + } + + return true; +} + +int run_gemm_splitk_multi_d_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_wmma_gemm(problem_size, config); +} diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp index fdf49a31b7..1b8194f838 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -51,9 +51,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; // clang-format on #include "run_splitK_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return !run_splitK_gemm_example(argc, argv); +} diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp index dc54bc30ef..8628e8770c 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -50,9 +50,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_splitK_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return !run_splitK_gemm_example(argc, argv); +} diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp index b93639e6c1..8091a5b448 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025 Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -57,4 +57,12 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu #include "run_splitK_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return !run_splitK_gemm_example(argc, argv); +} diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp index 7506f69420..4257451754 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -55,4 +55,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu #include "run_splitK_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + return !run_splitK_gemm_example(argc, argv); +} diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp index 7ebf914408..f0d4e28ad2 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -89,4 +89,12 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu #define BUILD_INT4_EXAMPLE #include "run_splitK_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return !run_splitK_gemm_example(argc, argv); +} diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp index 6b0c1aa02d..d800443932 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -48,9 +48,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>; // clang-format on #include "run_splitK_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return !run_splitK_gemm_example(argc, argv); +} diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp index fc55019fc4..ef27c7bb9f 100644 --- a/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp +++ b/example/35_splitK_gemm/splitK_gemm_xdl_lds_direct_load_fp16.cpp @@ -79,4 +79,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlS #include "run_splitK_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + return !run_splitK_gemm_example(argc, argv); +} diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index 26a03f289d..ff7acea48d 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -50,14 +50,14 @@ template<> struct emb_kernel { using kernel_type = DeviceInsta // clang-format on -int main() +int main(int argc, char* argv[]) { bool time_kernel = true; - constexpr auto num_rows = 65536; - constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{}; - // constexpr auto dims = ck::Sequence<256, 512>{}; - constexpr auto index_length = 2048; + ck::index_t num_rows = 65536; + constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{}; + ck::index_t index_length = 2048; + ck::index_t dim_mask = 0xffff; constexpr AccDataType epsilon = 1e-4; auto f_host_tensor_desc_1d = [](std::size_t len_) { return HostTensorDescriptor({len_}); }; @@ -73,121 +73,143 @@ int main() BetaDataType, AccDataType, OutType>; + if(argc == 1) + { + // Use default value + } + else if(argc == 5) + { + time_kernel = std::stoi(argv[1]); + num_rows = std::stoi(argv[2]); + dim_mask = strtol(argv[3], nullptr, 0); + index_length = std::stoi(argv[4]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "arg1: time kernel (0=no, 1=yes)" << std::endl; + std::cout << "arg2-4: num_rows dim_mask index_length" << std::endl; + return 1; + } ck::static_for<0, dims.Size(), 1>{}([&](auto I) { - std::srand(std::time(nullptr)); - constexpr auto current_dim = dims.At(I); - Tensor emb_a(f_host_tensor_desc_2d(num_rows, current_dim)); - Tensor emb_b(f_host_tensor_desc_2d(num_rows, current_dim)); - Tensor emb_c(f_host_tensor_desc_2d(num_rows, current_dim)); - - Tensor index_a(f_host_tensor_desc_1d(index_length)); - Tensor index_b(f_host_tensor_desc_1d(index_length)); - Tensor index_c(f_host_tensor_desc_1d(index_length)); - - Tensor gamma(f_host_tensor_desc_1d(current_dim)); - Tensor beta(f_host_tensor_desc_1d(current_dim)); - - Tensor out(f_host_tensor_desc_2d(index_length, current_dim)); - - emb_a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - emb_b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - emb_c.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - index_a.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); - index_b.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); - index_c.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); - - gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize()); - DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize()); - DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize()); - - DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize()); - DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize()); - DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize()); - - DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); - DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); - - DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize()); - - emb_a_dev.ToDevice(emb_a.mData.data()); - emb_b_dev.ToDevice(emb_b.mData.data()); - emb_c_dev.ToDevice(emb_c.mData.data()); - - index_a_dev.ToDevice(index_a.mData.data()); - index_b_dev.ToDevice(index_b.mData.data()); - index_c_dev.ToDevice(index_c.mData.data()); - - gamma_dev.ToDevice(gamma.mData.data()); - beta_dev.ToDevice(beta.mData.data()); - - auto device_instance = typename emb_kernel::kernel_type{}; - auto argument_ptr = device_instance.MakeArgumentPointer( - out_dev.GetDeviceBuffer(), - {ck::type_convert(emb_a_dev.GetDeviceBuffer()), - ck::type_convert(emb_b_dev.GetDeviceBuffer()), - ck::type_convert(emb_c_dev.GetDeviceBuffer())}, - {ck::type_convert(index_a_dev.GetDeviceBuffer()), - ck::type_convert(index_b_dev.GetDeviceBuffer()), - ck::type_convert(index_c_dev.GetDeviceBuffer())}, - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - current_dim, - index_length, - epsilon, - EmbElementwiseOperation{}); - std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() - << std::endl - << std::flush; - - bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get()); - - if(!is_supported) + if(dim_mask & (1 << I.value)) { - std::cout << "Runtime parameters are not supported" << std::endl; - return; + std::srand(std::time(nullptr)); + constexpr auto current_dim = dims.At(I); + Tensor emb_a(f_host_tensor_desc_2d(num_rows, current_dim)); + Tensor emb_b(f_host_tensor_desc_2d(num_rows, current_dim)); + Tensor emb_c(f_host_tensor_desc_2d(num_rows, current_dim)); + + Tensor index_a(f_host_tensor_desc_1d(index_length)); + Tensor index_b(f_host_tensor_desc_1d(index_length)); + Tensor index_c(f_host_tensor_desc_1d(index_length)); + + Tensor gamma(f_host_tensor_desc_1d(current_dim)); + Tensor beta(f_host_tensor_desc_1d(current_dim)); + + Tensor out(f_host_tensor_desc_2d(index_length, current_dim)); + + emb_a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + emb_b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + emb_c.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + index_a.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + index_b.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + index_c.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize()); + DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize()); + DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize()); + + DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize()); + DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize()); + DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize()); + + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + + DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize()); + + emb_a_dev.ToDevice(emb_a.mData.data()); + emb_b_dev.ToDevice(emb_b.mData.data()); + emb_c_dev.ToDevice(emb_c.mData.data()); + + index_a_dev.ToDevice(index_a.mData.data()); + index_b_dev.ToDevice(index_b.mData.data()); + index_c_dev.ToDevice(index_c.mData.data()); + + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = typename emb_kernel::kernel_type{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + out_dev.GetDeviceBuffer(), + {ck::type_convert(emb_a_dev.GetDeviceBuffer()), + ck::type_convert(emb_b_dev.GetDeviceBuffer()), + ck::type_convert(emb_c_dev.GetDeviceBuffer())}, + {ck::type_convert(index_a_dev.GetDeviceBuffer()), + ck::type_convert(index_b_dev.GetDeviceBuffer()), + ck::type_convert(index_c_dev.GetDeviceBuffer())}, + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + current_dim, + index_length, + epsilon, + EmbElementwiseOperation{}); + std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() + << std::endl + << std::flush; + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cerr << device_instance.GetTypeString() << " does not support this problem" + << std::endl; + return; + } + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + float time_ms = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor out_from_dev(f_host_tensor_desc_2d(index_length, current_dim)); + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(out, + emb_a, + emb_b, + emb_c, + index_a, + index_b, + index_c, + gamma, + beta, + num_rows, + current_dim, + index_length, + epsilon); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + out_dev.FromDevice(out_from_dev.mData.data()); + pass &= + ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3); + } + + double total_read = current_dim * index_length * 3 * sizeof(EmbType) + + current_dim * sizeof(GammaDataType) + + current_dim * sizeof(BetaDataType); + double total_write = current_dim * index_length * sizeof(OutType); + double gbps = (total_read + total_write) / time_ms / 1e6; + + std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms + << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl + << std::flush; } - - auto invoker_ptr = device_instance.MakeInvokerPointer(); - float time_ms = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - bool pass = true; - { - Tensor out_from_dev(f_host_tensor_desc_2d(index_length, current_dim)); - ReferenceInstance ref; - auto ref_argument = ref.MakeArgument(out, - emb_a, - emb_b, - emb_c, - index_a, - index_b, - index_c, - gamma, - beta, - num_rows, - current_dim, - index_length, - epsilon); - auto ref_invoker = ref.MakeInvoker(); - ref_invoker.Run(ref_argument); - - out_dev.FromDevice(out_from_dev.mData.data()); - pass &= ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3); - } - - double total_read = current_dim * index_length * 3 * sizeof(EmbType) + - current_dim * sizeof(GammaDataType) + - current_dim * sizeof(BetaDataType); - double total_write = current_dim * index_length * sizeof(OutType); - double gbps = (total_read + total_write) / time_ms / 1e6; - - std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms - << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl - << std::flush; }); return 0; diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index f27dc60541..4934f74393 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o] @@ -154,11 +154,11 @@ using DeviceGemmInstance = 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 4, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 8, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -185,7 +185,7 @@ using DeviceGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock int main(int argc, char* argv[]) { @@ -321,11 +321,13 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp index 4c28e25e01..a377685e52 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "common.hpp" @@ -26,7 +26,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; + < NDimSpatial, OutLayout, WeiLayout, BiasLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, ck::Tuple, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_conv_bwd_data_bias_relu_example.inc" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp index b1554412b1..59d94c34bb 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "common.hpp" @@ -26,7 +26,7 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>; + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_conv_bwd_data_example.inc" diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp index 41023ef82a..d49fb9befb 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" #include "common.hpp" @@ -30,9 +30,17 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat // ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, AComputeType, BComputeType>; + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 16, 16, 4, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, AComputeType, BComputeType>; // clang-format on #include "run_grouped_conv_bwd_data_example.inc" -int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return run_grouped_conv_bwd_data_example(argc, argv); +} diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc index 0f0b120cbc..80d56cd781 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc @@ -206,7 +206,8 @@ int run_grouped_conv_bwd_data_bias_relu_example(int argc, char* argv[]) 1, // c 0, // hi 0 // wi - }); + }, + ctc::GNCHW{}); // input image: GNHWC const auto in_g_n_c_wis_desc = diff --git a/example/39_permute/permute_1xHxW_fp16.cpp b/example/39_permute/permute_1xHxW_fp16.cpp index 7336c3b631..30cf4ef083 100644 --- a/example/39_permute/permute_1xHxW_fp16.cpp +++ b/example/39_permute/permute_1xHxW_fp16.cpp @@ -17,4 +17,23 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl #include "run_permute_element_example.inc" -int main() { return !run_permute_element_example({1, 32000, 80}, {0, 2, 1}); } +int main(int argc, char* argv[]) +{ + bool time_kernel = false; + + if(argc == 1) + { + // use default + } + else if(argc == 2) + { + time_kernel = std::stoi(argv[1]); + } + else + { + printf("arg1: time kernel (0=no, 1=yes, default=0)\n"); + exit(0); + } + + return !run_permute_element_example({1, 32000, 80}, {0, 2, 1}, time_kernel); +} diff --git a/example/39_permute/permute_HxWx4_fp16.cpp b/example/39_permute/permute_HxWx4_fp16.cpp index 6c24919ded..c655384301 100644 --- a/example/39_permute/permute_HxWx4_fp16.cpp +++ b/example/39_permute/permute_HxWx4_fp16.cpp @@ -19,4 +19,23 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl #include "run_permute_bundle_example.inc" -int main() { return !run_permute_bundle_example({1, 80, 32000}, {0, 2, 1}); } +int main(int argc, char* argv[]) +{ + bool time_kernel = false; + + if(argc == 1) + { + // use default + } + else if(argc == 2) + { + time_kernel = std::stoi(argv[1]); + } + else + { + printf("arg1: time kernel (0=no, 1=yes, default=0)\n"); + exit(0); + } + + return !run_permute_bundle_example({1, 80, 32000}, {0, 2, 1}, time_kernel); +} diff --git a/example/39_permute/permute_NxHxW_fp16.cpp b/example/39_permute/permute_NxHxW_fp16.cpp index 3551d2a7c8..d3d7f47ced 100644 --- a/example/39_permute/permute_NxHxW_fp16.cpp +++ b/example/39_permute/permute_NxHxW_fp16.cpp @@ -17,4 +17,23 @@ using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl #include "run_permute_element_example.inc" -int main() { return !run_permute_element_example({121, 768, 80}, {0, 2, 1}); } +int main(int argc, char* argv[]) +{ + bool time_kernel = false; + + if(argc == 1) + { + // use default + } + else if(argc == 2) + { + time_kernel = std::stoi(argv[1]); + } + else + { + printf("arg1: time kernel (0=no, 1=yes, default=0)\n"); + exit(0); + } + + return !run_permute_element_example({121, 768, 80}, {0, 2, 1}, time_kernel); +} diff --git a/example/39_permute/run_permute_bundle_example.inc b/example/39_permute/run_permute_bundle_example.inc index 2c19872922..fab02f8cf3 100644 --- a/example/39_permute/run_permute_bundle_example.inc +++ b/example/39_permute/run_permute_bundle_example.inc @@ -3,7 +3,7 @@ #pragma once -bool run_permute_bundle(const Problem& problem) +bool run_permute_bundle(const Problem& problem, bool time_kernel) { const auto& input_bundle_shape = problem.shape; const auto& input_bundle_axes = problem.axes; @@ -41,7 +41,7 @@ bool run_permute_bundle(const Problem& problem) }; auto invoker = permute.MakeInvoker(); - float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::cout << "Perf: " << ave_time << " ms" << std::endl; @@ -72,7 +72,9 @@ bool run_permute_bundle(const Problem& problem) 1e-6); } -bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes) +bool run_permute_bundle_example(const Problem::Shape& shape, + const Problem::Axes& axes, + bool time_kernel) { - return run_permute_bundle(Problem{shape, axes}); + return run_permute_bundle(Problem{shape, axes}, time_kernel); } diff --git a/example/39_permute/run_permute_element_example.inc b/example/39_permute/run_permute_element_example.inc index 3587134456..c3f3b972e9 100644 --- a/example/39_permute/run_permute_element_example.inc +++ b/example/39_permute/run_permute_element_example.inc @@ -3,7 +3,7 @@ #pragma once -bool run_permute_element(const Problem& problem) +bool run_permute_element(const Problem& problem, bool time_kernel) { const auto& input_shape = problem.shape; const auto& input_axes = problem.axes; @@ -40,7 +40,7 @@ bool run_permute_element(const Problem& problem) }; auto invoker = permute.MakeInvoker(); - float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::cout << "Perf: " << ave_time << " ms" << std::endl; @@ -59,7 +59,9 @@ bool run_permute_element(const Problem& problem) 1e-6); } -bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes) +bool run_permute_element_example(const Problem::Shape& shape, + const Problem::Axes& axes, + bool time_kernel) { - return run_permute_element(Problem{shape, axes}); + return run_permute_element(Problem{shape, axes}, time_kernel); } diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp index 4573c68658..f9a7d9f638 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp @@ -78,8 +78,28 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_bias_perchannel_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + const auto out_element_op = OutElementOp{ActivationOp{}}; - run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); + run_conv2d_fwd_bias_perchannel_quantization_example( + out_element_op, do_verification, time_kernel); }; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp index 005f6263fd..333987edd6 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp @@ -76,9 +76,28 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_bias_perlayer_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + float requant_scale = 0.5f; const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; - run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); + run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op, do_verification, time_kernel); } diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp index 62e5e583de..4b94045421 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp @@ -79,9 +79,29 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_bias_perchannel_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + float scale_z_inv = 0.5f; const auto out_element_op = OutElementOp{scale_z_inv, ActivationOp{}}; - run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); + run_conv2d_fwd_bias_perchannel_quantization_example( + out_element_op, do_verification, time_kernel); }; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp index ef98fe7e4f..b74e06b10a 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp @@ -76,10 +76,29 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_bias_perlayer_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + float scale_acc = 0.5f; float scale_z_inv = 0.5f; const auto out_element_op = OutElementOp{scale_z_inv, scale_acc, ActivationOp{}}; - run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); + run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op, do_verification, time_kernel); } diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp index e524ddb2b2..c3ac40a1bc 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp @@ -76,8 +76,27 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_perchannel_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + const auto out_element_op = OutElementOp{ActivationOp{}}; - run_conv2d_fwd_perchannel_quantization_example(out_element_op); + run_conv2d_fwd_perchannel_quantization_example(out_element_op, do_verification, time_kernel); } diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp index d29a3143c0..437fd6f4c2 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp @@ -71,9 +71,28 @@ using DeviceGroupedConvNDFwdInstance = #include "run_conv2d_fwd_perlayer_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = false; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + float requant_scale = 0.5f; const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; - run_conv2d_fwd_perlayer_quantization_example(out_element_op); + run_conv2d_fwd_perlayer_quantization_example(out_element_op, do_verification, time_kernel); } diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp index 8c0049b0fa..d9cfae2898 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" @@ -57,10 +57,10 @@ using DeviceGroupedConvNDFwdInstance = 64, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -77,13 +77,33 @@ using DeviceGroupedConvNDFwdInstance = 1, // BBlockLdsExtraN 1, 1, - S<1, 64, 1, 4>, - 8>; + S<1, 32, 1, 8>, + 4>; #include "run_conv2d_fwd_bias_perchannel_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + const auto out_element_op = OutElementOp{ActivationOp{}}; - run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); + run_conv2d_fwd_bias_perchannel_quantization_example( + out_element_op, do_verification, time_kernel); }; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp index e18c123f7c..9d3024fce7 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" @@ -55,10 +55,10 @@ using DeviceGroupedConvNDFwdInstance = 64, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -75,14 +75,33 @@ using DeviceGroupedConvNDFwdInstance = 1, // BBlockLdsExtraN 1, 1, - S<1, 64, 1, 4>, - 8>; + S<1, 32, 1, 8>, + 4>; #include "run_conv2d_fwd_bias_perlayer_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + float requant_scale = 0.5f; const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; - run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); + run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op, do_verification, time_kernel); } diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp index 53f810cc9e..2d4ae1f837 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" @@ -55,10 +55,10 @@ using DeviceGroupedConvNDFwdInstance = 64, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -75,13 +75,32 @@ using DeviceGroupedConvNDFwdInstance = 1, // BBlockLdsExtraN 1, 1, - S<1, 64, 1, 4>, - 8>; + S<1, 32, 1, 8>, + 4>; #include "run_conv2d_fwd_perchannel_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + const auto out_element_op = OutElementOp{ActivationOp{}}; - run_conv2d_fwd_perchannel_quantization_example(out_element_op); + run_conv2d_fwd_perchannel_quantization_example(out_element_op, do_verification, time_kernel); } diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp index 9db6e201dd..79b0c00fa5 100644 --- a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" @@ -50,10 +50,10 @@ using DeviceGroupedConvNDFwdInstance = 64, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -70,14 +70,33 @@ using DeviceGroupedConvNDFwdInstance = 1, // BBlockLdsExtraN 1, 1, - S<1, 64, 1, 4>, - 16>; + S<1, 32, 1, 8>, + 4>; #include "run_conv2d_fwd_perlayer_quantization_example.inc" -int main() +int main(int argc, char* argv[]) { + bool do_verification = true; + bool time_kernel = false; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + float requant_scale = 0.5f; const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; - run_conv2d_fwd_perlayer_quantization_example(out_element_op); + run_conv2d_fwd_perlayer_quantization_example(out_element_op, do_verification, time_kernel); } diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc index e5b924ad51..3c089688cf 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc @@ -167,10 +167,10 @@ bool run_grouped_conv_fwd(bool do_verification, return (pass ? 0 : 1); } -int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_element_op) +int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_element_op, + bool do_verification, + bool time_kernel) { - bool do_verification = true; - bool time_kernel = true; const ck::index_t ndim_spatial = 2; ck::utils::conv::ConvParam conv_param{ @@ -214,7 +214,8 @@ int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_ 1, // k 0, // ho 0 // wo - }); + }, + BiasLayout{}); const auto requant_scale_g_k_desc = bias_g_k_desc; diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc index 9f3a769dcf..ed7886e76b 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc @@ -155,10 +155,10 @@ bool run_grouped_conv_fwd(bool do_verification, return (pass ? 0 : 1); } -int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op) +int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op, + bool do_verification, + bool time_kernel) { - bool do_verification = true; - bool time_kernel = true; const ck::index_t ndim_spatial = 2; ck::utils::conv::ConvParam conv_param{ @@ -201,7 +201,8 @@ int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_el 1, // k 0, // ho 0 // wo - }); + }, + BiasLayout{}); const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc index 9b08fc690d..12fdf425bf 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc @@ -157,10 +157,10 @@ bool run_grouped_conv_fwd(bool do_verification, return (pass ? 0 : 1); } -int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op) +int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op, + bool do_verification, + bool time_kernel) { - bool do_verification = true; - bool time_kernel = true; const ck::index_t ndim_spatial = 2; ck::utils::conv::ConvParam conv_param{ @@ -203,7 +203,8 @@ int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_eleme 1, // k 0, // ho 0 // wo - }); + }, + RequantScaleLayout{}); const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc index 267c737e00..eae6e996cc 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc @@ -139,10 +139,10 @@ bool run_grouped_conv_fwd(bool do_verification, return (pass ? 0 : 1); } -int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op) +int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op, + bool do_verification, + bool time_kernel) { - bool do_verification = true; - bool time_kernel = false; const ck::index_t ndim_spatial = 2; ck::utils::conv::ConvParam conv_param{ diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp index e37d413695..ba589ec044 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -73,11 +73,11 @@ using DeviceBatchedGemmGemmInstance = 8, // AK1 8, // BK1 4, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 4, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 8, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -102,8 +102,16 @@ using DeviceBatchedGemmGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock #include "run_grouped_conv_conv_fwd_example.inc" -int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } +int main(int argc, char* argv[]) +{ + // disable on gfx11 due to precsion issue. + if(ck::is_gfx11_supported()) + { + return 0; + } + return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp index 496e676a40..847859068f 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -73,11 +73,11 @@ using DeviceBatchedGemmGemmInstance = 8, // AK1 8, // BK1 4, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 4, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 8, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -102,7 +102,7 @@ using DeviceBatchedGemmGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // CShuffleBlockTransferScalarPerVector_NPerBlock #include "run_grouped_conv_conv_fwd_example.inc" diff --git a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp index 35d50721dc..9a104dbfab 100644 --- a/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp +++ b/example/41_grouped_conv_conv_fwd/grouped_conv_conv_fwd_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -106,4 +106,11 @@ using DeviceBatchedGemmGemmInstance = #include "run_grouped_conv_conv_fwd_example.inc" -int main(int argc, char* argv[]) { return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; } +int main(int argc, char* argv[]) +{ + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + return run_grouped_conv_conv_fwd_example(argc, argv) ? 0 : 1; +} diff --git a/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc b/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc index 0722d497d8..852a9bef88 100644 --- a/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc +++ b/example/41_grouped_conv_conv_fwd/run_grouped_conv_conv_fwd_example.inc @@ -257,7 +257,7 @@ bool run_grouped_conv_conv_fwd(bool do_verification, #endif return ck::utils::check_err( - out1_device, out1_host, "Error: incorrect results!", 1e-5f, 1e-4f); + out1_device, out1_host, "Error: incorrect results!", 1e-3f, 1.5e-3f); } return true; diff --git a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc index ab6f317bc6..86e1c8ccc8 100644 --- a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -11,21 +11,36 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) ck::index_t G = 64; ck::index_t C = 128; + bool do_verification = true; + bool time_kernel = true; + bool log_kernel = true; + if(argc == 1) { // use default case } - else if(argc == 6) + else if(argc == 4) { - N = std::stoi(argv[1]); - H = std::stoi(argv[2]); - W = std::stoi(argv[3]); - G = std::stoi(argv[4]); - C = std::stoi(argv[5]); + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + log_kernel = std::stoi(argv[3]); + } + else if(argc == 9) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + log_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + H = std::stoi(argv[5]); + W = std::stoi(argv[6]); + G = std::stoi(argv[7]); + C = std::stoi(argv[8]); } else { - std::cerr << "arg1 to 5: N, H, W, G, C" << std::endl; + std::cerr << "arg1 = verify(0=no, 1=yes), arg2 = time kernels(0=no, 1=yes), arg3 = log " + "kernels(0=no, 1=yes), arg4 to 8: N, H, W, G, C" + << std::endl; return 1; } @@ -94,7 +109,8 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) device_instance.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); auto invoker_ptr = device_instance.MakeInvokerPointer(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true}); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, log_kernel}); std::size_t num_btype = sizeof(XDataType) * N * H * W * G * C + sizeof(YDataType) * N * H * W * G * C + sizeof(GammaDataType) * G * C + @@ -106,6 +122,7 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) << device_instance.GetTypeString() << std::endl; bool pass = true; + if(do_verification) { Tensor host_y({N, H, W, G, C}); Tensor host_save_mean(HostTensorDescriptor{N, G}); diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp index ebba88cf41..873982227b 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -22,6 +22,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -53,7 +56,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; + DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>; // clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -250,19 +253,24 @@ int main(int argc, char* argv[]) Tensor a_gs_ms_ks( std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), - std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()), + Row{}); Tensor b_gs_ns_ks( std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), - std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()), + Row{}); Tensor d_gs_ms_ns( std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), - std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_device_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -372,7 +380,8 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 #include @@ -22,6 +22,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -53,7 +56,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>; + DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 4, 4, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 2>; // clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -250,19 +253,24 @@ int main(int argc, char* argv[]) Tensor a_gs_ms_ks( std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), - std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()), + Row{}); Tensor b_gs_ns_ks( std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), - std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()), + Row{}); Tensor d_gs_ms_ns( std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), - std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_device_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -372,7 +380,8 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 #include @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -44,12 +46,39 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8, 8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = true; std::vector nchw = {16, 128, 32, 64}; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + nchw[0] = std::stoi(argv[3]); + nchw[1] = std::stoi(argv[4]); + nchw[2] = std::stoi(argv[5]); + nchw[3] = std::stoi(argv[6]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + printf("arg3-6: N, C, H, W (default 16, 128, 32, 64)\n"); + exit(1); + } + std::array ab_lengths; std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), static_cast(nchw[2] * nchw[3]), @@ -57,11 +86,11 @@ int main() 1}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 2> as = {Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides)}; + std::array, 2> as = {Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{})}; Tensor& a0 = as[0]; Tensor& a1 = as[1]; - Tensor b(ab_lengths, ab_strides); + Tensor b(ab_lengths, ab_strides, NchwLayout{}); float alpha = 3.f; float beta = 2.f; a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -118,7 +147,7 @@ int main() if(do_verification) { - Tensor host_b(ab_lengths, ab_strides); + Tensor host_b(ab_lengths, ab_strides, NchwLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 3ea1aa4bf8..2d689648f2 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple @@ -37,11 +39,27 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = true; + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; @@ -56,9 +74,9 @@ int main() static_cast(nhwc[3])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -101,7 +119,7 @@ int main() if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index 13c67fce05..6e70a306d3 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -23,6 +23,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -41,11 +43,27 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = true; + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + std::vector nchw = {16, 8, 32, 64}; std::vector nhwc = {16, 32, 64, 8}; std::array ab_lengths; @@ -60,9 +78,9 @@ int main() static_cast(nhwc[0] * nhwc[1])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 1.f; auto i = 0; std::mt19937 gen(11939); @@ -121,7 +139,7 @@ int main() if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index 0a0f6fec10..632d88e88a 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -40,11 +43,27 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = true; + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; @@ -60,9 +79,9 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 2.f; a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -112,7 +131,7 @@ int main() if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index fc664186be..bd54f1c19c 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F32; using BDataType = F32; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -40,11 +42,27 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = true; + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + std::vector nchw = {16, 8, 32, 64}; std::vector nhwc = {16, 32, 64, 8}; std::array ab_lengths; @@ -60,9 +78,9 @@ int main() static_cast(nhwc[0] * nhwc[1])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 1.f; auto i = 0; @@ -123,7 +141,7 @@ int main() if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index a0c416318a..9621d591a9 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F32; using BDataType = F32; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -40,11 +43,27 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = true; + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; @@ -60,9 +79,9 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 2.f; a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -111,7 +130,7 @@ int main() if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index c40447e1f9..1ce797e4dd 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -124,10 +124,28 @@ int main(int argc, char* argv[]) ck::index_t M = 1024; ck::index_t K = 1024; - if(argc == 3) + if(argc == 1) { - M = std::stoi(argv[1]); - K = std::stoi(argv[2]); + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else if(argc == 5) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + M = std::stoi(argv[3]); + K = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + printf("arg3-4: M(default=1024), K(default=1024)\n"); + exit(1); } std::array dims = {M, K}; diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp index 050300eed2..be4014f636 100644 --- a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -48,11 +51,27 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -int main() +int main(int argc, char* argv[]) { bool do_verification = true; bool time_kernel = true; + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + std::vector nchw = {16, 128, 32, 64}; std::array ab_lengths; std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -62,13 +81,13 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 3> as = {Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides)}; + std::array, 3> as = {Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{})}; Tensor& a0 = as[0]; Tensor& a1 = as[1]; Tensor& a2 = as[2]; - Tensor b(ab_lengths, ab_strides); + Tensor b(ab_lengths, ab_strides, NchwLayout{}); float alpha = 3.f; float beta = 2.f; float gamma = 4.f; @@ -133,7 +152,7 @@ int main() if(do_verification) { - Tensor host_b(ab_lengths, ab_strides); + Tensor host_b(ab_lengths, ab_strides, NchwLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index c02d540983..8064809123 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -77,12 +77,44 @@ void host_elementwise2D(HostTensorC& C, } } -int main() +int main(int argc, char* argv[]) { - bool time_kernel = true; + bool do_verification = true; + bool time_kernel = true; + + if(argc == 1) + { + // use default + } + else if(argc == 3) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: time kernel (0=no, 1=yes)\n"); + exit(0); + } + + ck::index_t M = 48 * 256; + ck::index_t N = 1024; + if(argc == 1) + { + // use default case + } + else if(argc == 3) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + } + else + { + std::cerr << "arg1 to 2: M, N" << std::endl; + return 1; + } - ck::index_t M = 48 * 256; - ck::index_t N = 1024; ck::index_t Stride = N; auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { @@ -157,6 +189,7 @@ int main() std::cout << "Time elapase is : " << ela_time << " ms . " << std::endl; bool pass = true; + if(do_verification) { std::vector mn = {static_cast(M), static_cast(N)}; diff --git a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp index 56417b101d..4d73f0c35f 100644 --- a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp +++ b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" @@ -31,7 +31,7 @@ using DeviceOpInstance = ck::tensor_operation::device:: //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>; + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 16, 16, 8, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; - auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; + ProblemSize ps = + problem_size; // make mutable copy because default stride values of 0 need to be updated + auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = ps; - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, int& stride, auto layout) { + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); @@ -123,7 +131,16 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + if(std::is_same_v, ck::half_t> && + std::is_same_v, ck::half_t>) + { + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-3, 1e-3); + } + else + { + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } } return true; diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp index 392cb155cb..3e69caf51e 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -18,6 +18,10 @@ #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -91,11 +95,11 @@ using DeviceOpInstance = 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -120,7 +124,7 @@ using DeviceOpInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec, // MaskingSpecialization 1>; @@ -159,6 +163,12 @@ int main(int argc, char* argv[]) int O = 64; float alpha = 1; + // temp disable on gfx11, d0_gs_ms_ns isn't handled correctly when it is not a constant. + if(ck::is_gfx11_supported()) + { + return 0; + } + if(argc == 1) { // use default case @@ -214,12 +224,12 @@ int main(int argc, char* argv[]) std::vector d0_gs_ms_ns_lengths{G0, G1, M, N}; std::vector d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Row{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Col{}); + Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides, Row{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Row{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Row{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/48_pool3d_fwd/pool3d_fwd_common.hpp b/example/48_pool3d_fwd/pool3d_fwd_common.hpp index 788f38ec52..ef64dd167d 100644 --- a/example/48_pool3d_fwd/pool3d_fwd_common.hpp +++ b/example/48_pool3d_fwd/pool3d_fwd_common.hpp @@ -48,15 +48,16 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, layout); } throw std::runtime_error("Pool3d_fwd: problem with layout. "); - return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, layout); }; template ::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, layout); } throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); - return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, layout); }; template ; // DBetaDstVectorSize -int main() +int main(int argc, char* argv[]) { bool time_kernel = false; @@ -110,6 +110,25 @@ int main() ck::index_t G = 32; ck::index_t C = 64; + if(argc == 1) + { + // use default case + } + else if(argc == 6) + { + N = std::stoi(argv[1]); + H = std::stoi(argv[2]); + W = std::stoi(argv[3]); + G = std::stoi(argv[4]); + C = std::stoi(argv[5]); + } + else + { + std::cerr << "arg1 to 5: N, H, W, G, C" << std::endl; + + return 1; + } + Tensor dy({N, H, W, G, C}); Tensor x({N, H, W, G, C}); Tensor gamma({G, C}); diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt index a9e0d3f9ad..ffc6cec61d 100644 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -1,3 +1,7 @@ +add_example_executable(example_gemm_multi_ABD_wmma_fp16 gemm_multi_ABD_wmma_fp16.cpp) +add_example_executable(example_gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp) +add_example_executable(example_gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp) +add_example_executable(example_gemm_multi_ABD_wmma_fastgelu_bf16_i8 gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp) add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) add_example_executable(example_gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp) add_example_executable(example_gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8 gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp) diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000..a30314f58c --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_bias_fastgelu_bf16_i8.cpp @@ -0,0 +1,307 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + AsLayout, + BsLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<8, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<8, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 2; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, StrideB}, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000..086a0f4834 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = FastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + AsLayout, + BsLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<8, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<8, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 2; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 0; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, StrideB}, + std::array{}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp new file mode 100644 index 0000000000..32345d1263 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using ELayout = Row; + +struct AddScale +{ + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + static constexpr auto I2 = ck::Number<2>{}; + static constexpr auto I3 = ck::Number<3>{}; + + __host__ __device__ constexpr void + operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const + { + const auto a0_v_t = ck::vector_type{a0}; + const auto a1_v_t = ck::vector_type{a1}; + + auto r_v_t = ck::vector_type{}; + + r_v_t.AsType()(I0) = + scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); + r_v_t.AsType()(I1) = + scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); + r_v_t.AsType()(I2) = + scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); + r_v_t.AsType()(I3) = + scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); + + a = r_v_t.AsType()[I0]; + } + + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const + { + a = scale * (a0 + a1); + } + + // this attribute controls the copy_function applying element_wise_op with + // pack4_data + constexpr const static bool is_pack4_invocable = true; + + float scale = 1.0; +}; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ELayout, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 256, + 128, + 32, + 8, + 8, + 16, + 16, + 4, + 4, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 64, 1, 4>, + S<8, 8, 8>>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 12: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{0.2}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA, StrideA}, + std::array{StrideB}, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000..00e2d7e33c --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using D1DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using D1Layout = D0Layout; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyAddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + AsLayout, + BsLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<8, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<8, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 1; + constexpr ck::index_t NumDTensor = 2; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer()}, + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{StrideD, StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp index 5f3bba922f..405eac7df1 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -67,7 +67,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; // clang-format on int main(int argc, char* argv[]) @@ -81,10 +81,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -120,23 +121,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) N, K, std::array{StrideA}, - std::array{StrideB, 0}, + std::array{StrideB, StrideB1}, std::array{StrideD}, StrideE, a_element_op, diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp index 95cf8f3674..50e670bdf3 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -67,7 +67,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; // clang-format on int main(int argc, char* argv[]) @@ -81,10 +81,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -120,23 +121,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) N, K, std::array{StrideA}, - std::array{StrideB, 0}, + std::array{StrideB, StrideB1}, std::array{}, StrideE, a_element_op, diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp index 2582ea8a11..2a44c8ad2a 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp @@ -127,10 +127,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl 32, 8, 8, - 32, - 32, + 16, + 16, + 8, 4, - 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -148,7 +148,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; int main(int argc, char* argv[]) { diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp index 07b9db4620..50e1c21c8f 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -66,7 +66,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl ///######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| ///######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; + < AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 4, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v4>; // clang-format on int main(int argc, char* argv[]) @@ -80,10 +80,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -119,23 +120,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) K, std::array{StrideA}, std::array{StrideB}, - std::array{0, StrideD}, + std::array{StrideB1, StrideD}, StrideE, a_element_op, b_element_op, @@ -261,7 +270,7 @@ int main(int argc, char* argv[]) { for(int n = 0; n < N; ++n) { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(0, n), d_m_n(m, n)); + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(m, n), d_m_n(m, n)); } } diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 57e2feb084..a9a30b4c27 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -19,6 +19,9 @@ #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -94,10 +97,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultiple 32, 8, 8, - 32, - 32, + 16, + 16, + 8, 4, - 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -115,7 +118,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultiple 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; int main(int argc, char* argv[]) { @@ -160,12 +163,12 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides, Bypass{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; @@ -264,9 +267,9 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) { @@ -299,7 +302,6 @@ int main(int argc, char* argv[]) auto ref_op = ReferenceOpInstance{}; auto ref_invoker = ref_op.MakeInvoker(); - Tensor empty_tensor(std::vector{}, std::vector{}); auto ref_argument = ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op); diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp index ec1b2d6018..4f7414abfa 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -19,6 +19,9 @@ #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -140,12 +143,12 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); - Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides, Bypass{}); + Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides, Row{}); + Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; @@ -246,9 +249,9 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) { @@ -266,7 +269,7 @@ int main(int argc, char* argv[]) } } - Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides, Row{}); for(size_t n0 = 0; n0 < b_ns_ks.mDesc.GetLengths()[0]; ++n0) { diff --git a/example/62_convnd_activ/binary/CMakeLists.txt b/example/62_convnd_activ/binary/CMakeLists.txt index b9584be89c..f23f908883 100644 --- a/example/62_convnd_activ/binary/CMakeLists.txt +++ b/example/62_convnd_activ/binary/CMakeLists.txt @@ -1,15 +1,9 @@ -list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_binary_xdl) - # Bilinear residual - add_example_executable(example_convnd_fwd_xdl_bilinear_residual_fp16 convnd_fwd_xdl_bilinear_residual_fp16.cpp) - add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_fwd_xdl_bilinear_residual_fp16) - add_example_executable(example_convnd_bwd_data_xdl_bilinear_residual_fp16 convnd_bwd_data_xdl_bilinear_residual_fp16.cpp) - add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_data_xdl_bilinear_residual_fp16) - add_example_executable(example_convnd_bwd_weight_xdl_bilinear_residual_fp16 convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp) - add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_weight_xdl_bilinear_residual_fp16) - set(target 1) - endif() -endforeach() +add_custom_target(example_convnd_activ_binary_xdl) +# Bilinear residual +add_example_executable(example_convnd_fwd_xdl_bilinear_residual_fp16 convnd_fwd_xdl_bilinear_residual_fp16.cpp) +add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_fwd_xdl_bilinear_residual_fp16) +add_example_executable(example_convnd_bwd_data_xdl_bilinear_residual_fp16 convnd_bwd_data_xdl_bilinear_residual_fp16.cpp) +add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_data_xdl_bilinear_residual_fp16) +add_example_executable(example_convnd_bwd_weight_xdl_bilinear_residual_fp16 convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp) +add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_weight_xdl_bilinear_residual_fp16) + diff --git a/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp index f5bddf2302..2710dd6b63 100644 --- a/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp +++ b/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -70,10 +70,10 @@ using DeviceGroupedConvNDBwdDataInstance = 32, // KPerBlock 8, // AK1 2, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -91,7 +91,7 @@ using DeviceGroupedConvNDBwdDataInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdDataInstance; diff --git a/example/62_convnd_activ/binary/convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp index fa3edc5adc..cb37ebf575 100644 --- a/example/62_convnd_activ/binary/convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp +++ b/example/62_convnd_activ/binary/convnd_bwd_weight_xdl_bilinear_residual_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -63,10 +63,10 @@ using DeviceGroupedConvNDBwdWeightInstance = 128, // NPerBlock 4, // K0PerBlock 8, // K1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 4, // NXdlPerWave S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder @@ -84,7 +84,7 @@ using DeviceGroupedConvNDBwdWeightInstance = 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 128 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl + 64 / (sizeof(WeiDataType) * CHAR_BIT)>; // CBlockTransferScalarPerVector_NWaveNPerXdl using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdWeightInstance; namespace { @@ -257,4 +257,12 @@ bool run_grouped_conv(bool do_verification, #include "../run_convnd_activ_example.inc" -int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } +int main(int argc, char* argv[]) +{ + // temp disable test on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return !run_convnd_example(argc, argv); +} diff --git a/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp index ae1ebcb2cd..616d0cc9e8 100644 --- a/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp +++ b/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -71,10 +71,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -92,7 +92,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; diff --git a/example/62_convnd_activ/convinvscale/CMakeLists.txt b/example/62_convnd_activ/convinvscale/CMakeLists.txt index 7aae090674..c737bc00ec 100644 --- a/example/62_convnd_activ/convinvscale/CMakeLists.txt +++ b/example/62_convnd_activ/convinvscale/CMakeLists.txt @@ -1,10 +1,5 @@ -list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_xdl_convinvscale) - add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8) - set(target 1) - endif() -endforeach() \ No newline at end of file +if (NOT GPU_TARGETS MATCHES "gfx11") + add_custom_target(example_convnd_activ_xdl_convinvscale) + add_example_executable(example_convnd_fwd_xdl_convinvscale_fp8 convnd_fwd_xdl_convinvscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convinvscale example_convnd_fwd_xdl_convinvscale_fp8) +endif() \ No newline at end of file diff --git a/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp index fbdfc72063..2194c536c0 100644 --- a/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp +++ b/example/62_convnd_activ/convinvscale/convnd_fwd_xdl_convinvscale_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_convinvscale_common.hpp" @@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeDataType, BComputeDataType>; diff --git a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp index d101fd59bd..0a802ee27d 100644 --- a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -74,10 +74,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -95,7 +95,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; @@ -130,11 +130,12 @@ bool run_grouped_conv(bool do_verification, // Fill other lenghts than G,K with 1 and strides with 0 bias_g_k_lengths.fill(1); bias_g_k_strides.fill(0); - bias_g_k_lengths[0] = G; - bias_g_k_lengths[2] = K; - bias_g_k_strides[0] = K; // stride to G - bias_g_k_strides[2] = 1; // stride to K - const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides); + bias_g_k_lengths[0] = G; + bias_g_k_lengths[2] = K; + bias_g_k_strides[0] = K; // stride to G + bias_g_k_strides[2] = 1; // stride to K + const auto broadcasted_bias_desc = + HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides, BiasLayout{}); // y = relu ( alpha1 * conv(x) + alpha2 * z + bias ) Tensor in(in_g_n_c_wis_desc); diff --git a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp index f784655cc5..3266c55d7c 100644 --- a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -71,10 +71,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -92,7 +92,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; diff --git a/example/62_convnd_activ/convscale/CMakeLists.txt b/example/62_convnd_activ/convscale/CMakeLists.txt index 26f6c1b168..8746a5ad54 100644 --- a/example/62_convnd_activ/convscale/CMakeLists.txt +++ b/example/62_convnd_activ/convscale/CMakeLists.txt @@ -1,20 +1,14 @@ -list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_xdl_convscale) - add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8 ) +if (NOT GPU_TARGETS MATCHES "gfx11") + add_custom_target(example_convnd_activ_xdl_convscale) + add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8 ) - add_example_executable(example_convnd_fwd_xdl_convscale_bf8 convnd_fwd_xdl_convscale_bf8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8) + add_example_executable(example_convnd_fwd_xdl_convscale_bf8 convnd_fwd_xdl_convscale_bf8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8) - add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8) + add_example_executable(example_convnd_fwd_xdl_convscale_fp8_bf8 convnd_fwd_xdl_convscale_fp8_bf8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8_bf8) - add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8) - - set(target 1) - endif() -endforeach() + add_example_executable(example_convnd_fwd_xdl_convscale_bf8_fp8 convnd_fwd_xdl_convscale_bf8_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_bf8_fp8) +endif() diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp index c1c8c3a57f..f7ad53221c 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_convscale_common.hpp" @@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeDataType, BComputeDataType>; diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp index 8590d0620f..6f0337b85e 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_bf8_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_convscale_common.hpp" @@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeDataType, BComputeDataType>; diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp index a7d69ccffc..7046c93f9f 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_convscale_common.hpp" @@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeDataType, BComputeDataType>; diff --git a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp index ab59e08a80..3376b9aba3 100644 --- a/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp +++ b/example/62_convnd_activ/convscale/convnd_fwd_xdl_convscale_fp8_bf8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_convscale_common.hpp" @@ -58,10 +58,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -79,7 +79,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeDataType, BComputeDataType>; diff --git a/example/62_convnd_activ/convscale_add/CMakeLists.txt b/example/62_convnd_activ/convscale_add/CMakeLists.txt index b2e0eecb58..5dac630298 100644 --- a/example/62_convnd_activ/convscale_add/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_add/CMakeLists.txt @@ -1,11 +1,5 @@ -list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_xdl_convscale_add) - add_example_executable(example_convnd_fwd_xdl_convscale_add_fp8 convnd_fwd_xdl_convscale_add_fp8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale_add example_convnd_fwd_xdl_convscale_add_fp8 ) - - set(target 1) - endif() -endforeach() +if (NOT GPU_TARGETS MATCHES "gfx11") + add_custom_target(example_convnd_activ_xdl_convscale_add) + add_example_executable(example_convnd_fwd_xdl_convscale_add_fp8 convnd_fwd_xdl_convscale_add_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_add example_convnd_fwd_xdl_convscale_add_fp8) +endif() \ No newline at end of file diff --git a/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp b/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp index 3f592b2c54..71dddcfe91 100644 --- a/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp +++ b/example/62_convnd_activ/convscale_add/convnd_fwd_xdl_convscale_add_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/tuple.hpp" #include "convnd_fwd_convscale_add_common.hpp" @@ -57,10 +57,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -78,7 +78,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeDataType, BComputeDataType>; diff --git a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt index 739c855ae4..c1c64671b4 100644 --- a/example/62_convnd_activ/convscale_reduce/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_reduce/CMakeLists.txt @@ -1,14 +1,8 @@ -list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_xdl_convscale_reduce) - add_example_executable(example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_relu_amax_fp8) +if (NOT GPU_TARGETS MATCHES "gfx11") + add_custom_target(example_convnd_activ_xdl_convscale_reduce) + add_example_executable(example_convnd_fwd_xdl_convscale_relu_amax_fp8 convnd_fwd_xdl_convscale_relu_amax_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_relu_amax_fp8) - add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8) - - set(target 1) - endif() -endforeach() + add_example_executable(example_convnd_fwd_xdl_convscale_amax_fp8 convnd_fwd_xdl_convscale_amax_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_reduce example_convnd_fwd_xdl_convscale_amax_fp8) +endif() \ No newline at end of file diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp index a8b4fdbead..7f0b2329f6 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_amax_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_convscale_reduce_common.hpp" @@ -52,10 +52,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -73,7 +73,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeDataType, BComputeDataType>; diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp index df6bf7bd5c..9a7de75d00 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_xdl_convscale_relu_amax_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_convscale_reduce_common.hpp" @@ -52,10 +52,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -73,7 +73,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeDataType, BComputeDataType>; diff --git a/example/62_convnd_activ/convscale_relu/CMakeLists.txt b/example/62_convnd_activ/convscale_relu/CMakeLists.txt index c3241aecf2..024b79e2af 100644 --- a/example/62_convnd_activ/convscale_relu/CMakeLists.txt +++ b/example/62_convnd_activ/convscale_relu/CMakeLists.txt @@ -1,11 +1,5 @@ -list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_xdl_convscale_relu) - add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp) - add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8 ) - - set(target 1) - endif() -endforeach() +if (NOT GPU_TARGETS MATCHES "gfx11") + add_custom_target(example_convnd_activ_xdl_convscale_relu) + add_example_executable(example_convnd_fwd_xdl_convscale_relu_fp8 convnd_fwd_xdl_convscale_relu_fp8.cpp) + add_example_dependencies(example_convnd_activ_xdl_convscale_relu example_convnd_fwd_xdl_convscale_relu_fp8) +endif() diff --git a/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp index 360349e7ec..4fac49133c 100644 --- a/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp +++ b/example/62_convnd_activ/convscale_relu/convnd_fwd_xdl_convscale_relu_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_convscale_relu_common.hpp" @@ -56,10 +56,10 @@ using DeviceGroupedConvNDFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -77,7 +77,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8, + 4, AComputeDataType, BComputeDataType>; diff --git a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt index 8441030945..359b444dd0 100644 --- a/example/62_convnd_activ/dynamic_unary/CMakeLists.txt +++ b/example/62_convnd_activ/dynamic_unary/CMakeLists.txt @@ -1,45 +1,37 @@ -list(APPEND gpu_list gfx908 gfx90a gfx942 gfx950) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_dynamic_unary_xdl) - # Sigmoid - add_example_executable(example_convnd_fwd_xdl_dynamic_sigmoid_fp16 convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_sigmoid_fp16) - # Tanh - add_example_executable(example_convnd_fwd_xdl_dynamic_tanh_fp16 convnd_fwd_xdl_dynamic_tanh_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_tanh_fp16) - # Relu - add_example_executable(example_convnd_fwd_xdl_dynamic_relu_fp16 convnd_fwd_xdl_dynamic_relu_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_relu_fp16) - # SoftRelu - add_example_executable(example_convnd_fwd_xdl_dynamic_softrelu_fp16 convnd_fwd_xdl_dynamic_softrelu_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_softrelu_fp16) - # Abs - add_example_executable(example_convnd_fwd_xdl_dynamic_abs_fp16 convnd_fwd_xdl_dynamic_abs_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_abs_fp16) - # Pow - add_example_executable(example_convnd_fwd_xdl_dynamic_pow_fp16 convnd_fwd_xdl_dynamic_pow_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_pow_fp16) - # Clipped Relu - add_example_executable(example_convnd_fwd_xdl_dynamic_clippedrelu_fp16 convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_clippedrelu_fp16) - # Leaky Relu - add_example_executable(example_convnd_fwd_xdl_dynamic_leakyrelu_fp16 convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_leakyrelu_fp16) - # Elu - add_example_executable(example_convnd_fwd_xdl_dynamic_elu_fp16 convnd_fwd_xdl_dynamic_elu_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_elu_fp16) - # Swish - add_example_executable(example_convnd_fwd_xdl_dynamic_swish_fp16 convnd_fwd_xdl_dynamic_swish_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_swish_fp16) - # PassThrough - add_example_executable(example_convnd_fwd_xdl_dynamic_passthrough_fp16 convnd_fwd_xdl_dynamic_passthrough_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_passthrough_fp16) - # Logistic - add_example_executable(example_convnd_fwd_xdl_dynamic_logistic_fp16 convnd_fwd_xdl_dynamic_logistic_fp16.cpp) - add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16) - - set(target 1) - endif() -endforeach() +add_custom_target(example_convnd_activ_dynamic_unary_xdl) +# Sigmoid +add_example_executable(example_convnd_fwd_xdl_dynamic_sigmoid_fp16 convnd_fwd_xdl_dynamic_sigmoid_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_sigmoid_fp16) +# Tanh +add_example_executable(example_convnd_fwd_xdl_dynamic_tanh_fp16 convnd_fwd_xdl_dynamic_tanh_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_tanh_fp16) +# Relu +add_example_executable(example_convnd_fwd_xdl_dynamic_relu_fp16 convnd_fwd_xdl_dynamic_relu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_relu_fp16) +# SoftRelu +add_example_executable(example_convnd_fwd_xdl_dynamic_softrelu_fp16 convnd_fwd_xdl_dynamic_softrelu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_softrelu_fp16) +# Abs +add_example_executable(example_convnd_fwd_xdl_dynamic_abs_fp16 convnd_fwd_xdl_dynamic_abs_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_abs_fp16) +# Pow +add_example_executable(example_convnd_fwd_xdl_dynamic_pow_fp16 convnd_fwd_xdl_dynamic_pow_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_pow_fp16) +# Clipped Relu +add_example_executable(example_convnd_fwd_xdl_dynamic_clippedrelu_fp16 convnd_fwd_xdl_dynamic_clippedrelu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_clippedrelu_fp16) +# Leaky Relu +add_example_executable(example_convnd_fwd_xdl_dynamic_leakyrelu_fp16 convnd_fwd_xdl_dynamic_leakyrelu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_leakyrelu_fp16) +# Elu +add_example_executable(example_convnd_fwd_xdl_dynamic_elu_fp16 convnd_fwd_xdl_dynamic_elu_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_elu_fp16) +# Swish +add_example_executable(example_convnd_fwd_xdl_dynamic_swish_fp16 convnd_fwd_xdl_dynamic_swish_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_swish_fp16) +# PassThrough +add_example_executable(example_convnd_fwd_xdl_dynamic_passthrough_fp16 convnd_fwd_xdl_dynamic_passthrough_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_passthrough_fp16) +# Logistic +add_example_executable(example_convnd_fwd_xdl_dynamic_logistic_fp16 convnd_fwd_xdl_dynamic_logistic_fp16.cpp) +add_example_dependencies(example_convnd_activ_dynamic_unary_xdl example_convnd_fwd_xdl_dynamic_logistic_fp16) \ No newline at end of file diff --git a/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp b/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp index ed31be19ee..4af7f4535a 100644 --- a/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp +++ b/example/62_convnd_activ/dynamic_unary/convnd_fwd_activ_dynamic_unary_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -71,10 +71,10 @@ using DeviceGroupedConvNDActivInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -92,7 +92,7 @@ using DeviceGroupedConvNDActivInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; template #include @@ -68,10 +68,10 @@ using DeviceGroupedConvNDMultiABFwdInstance = 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 8, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -89,7 +89,7 @@ using DeviceGroupedConvNDMultiABFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; namespace { template , // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -92,7 +92,7 @@ using DeviceGroupedConvNDFwdInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; template a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor quant_b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // assume scale tensor is [1, n] - Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{})); + Tensor scale_k_n( + HostTensorDescriptor({K, N}, {0, 1_uz}, ck::tensor_layout::BypassLayoutVerification())); switch(config.init_method) { diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 2d9c794faa..74930d2b21 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -16,7 +16,7 @@ add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp) add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp) -list(APPEND gpu_list gfx942 gfx950) +list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx11-generic gfx12-generic) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) diff --git a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp index 086ea45d10..fe8fd9c100 100644 --- a/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp +++ b/example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp @@ -78,11 +78,17 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| ///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| ///###### RCR - < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 4>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // clang-format on int main(int argc, char* argv[]) { + // fp8 are not supported on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + bool do_verification = true; int init_method = 1; bool time_kernel = false; diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp index 69803c7eeb..8b8cee9e52 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp @@ -97,11 +97,12 @@ struct MultiplyMultiply } }; +static constexpr int KPack = 8; + void preShuffleBuffer(const F16* src, F16* dst, int N, int K, int NXdl) { - int KPack = 16 / sizeof(F16); int NLane = NXdl; - int KLane = 64 / NLane; + int KLane = ck::get_warp_size() / NLane; int K0 = K / (KLane * KPack); // K -> K0 KLane KPack @@ -147,12 +148,12 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, - 8, 8, - 32, 32, - 1, 1, + KPack, KPack, + 16, 16, + 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, - 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, + 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F16>; // clang-format on @@ -211,6 +212,12 @@ int main(int argc, char* argv[]) exit(0); } + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; @@ -234,6 +241,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -278,8 +307,6 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; @@ -301,7 +328,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index 352d373ae5..8da49ef85d 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -162,6 +162,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -202,8 +224,6 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; auto invoker = device_op.MakeInvoker(); @@ -218,7 +238,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index 5aa978fbf0..3b21f95119 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -91,6 +91,8 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; + ck::index_t KBatch = 1; + if(argc == 1) { // use default case @@ -101,7 +103,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 8) + else if(argc == 8 || argc == 9) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -113,6 +115,11 @@ int main(int argc, char* argv[]) flush_cache = std::stoi(argv[7]); + if(argc == 9) + { + KBatch = std::stoi(argv[8]); + } + StrideA = K; StrideB = K; StrideE = N; @@ -124,6 +131,7 @@ int main(int argc, char* argv[]) printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 6: M, N, K\n"); printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + printf("arg8: KBatch (default: 1)\n"); exit(0); } @@ -233,9 +241,9 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{}, e_device_buf.GetDeviceBuffer(), @@ -251,6 +259,7 @@ int main(int argc, char* argv[]) a_element_op, b_element_op, cde_element_op); + argument.KBatch = KBatch; if(!device_op.IsSupportedArgument(argument)) { diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index fe1eca51b0..3ee4955ae4 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -251,6 +251,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -295,8 +317,6 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; @@ -318,7 +338,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp index cbbd37408e..cc01d01e64 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -125,11 +125,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 64, 128, 256, 16, 16, - 32, 32, - 1, 2, + 16, 16, + 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, I8>; // clang-format on diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 9fe9fdde78..72ea7f1cb6 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -168,7 +168,7 @@ static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); -static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t EVec = 8 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul @@ -287,15 +287,18 @@ int main(int argc, char* argv[]) } } Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; @@ -422,7 +425,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( - {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}, Row{})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N * 2}, - {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; @@ -463,7 +467,7 @@ int main(int argc, char* argv[]) Tensor b_e_n_k({experts, K, N * 2}); e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); // handle scale before ref. for(int t = 0; t < tokens; ++t) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index f78e6e48a5..5e306ac6dd 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -121,6 +121,7 @@ struct MulABScaleExpertWeight }; static constexpr bool MulRoutedWeight = true; +static constexpr ck::index_t KPack = 32; using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true @@ -129,7 +130,6 @@ using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true #if 1 void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) { - int KPack = 32; int NLane = NXdl; int KLane = 64 / NLane; @@ -169,18 +169,19 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul + // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, MPerBlock, 64, 128, - 16, 32, + 16, KPack, 16, 16, - 8, 1, + 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 2, 1, S<1, 32, 1, 8>, S<8, 1, 1>, + 2, 1, S<1, 32, 1, 8>, S<4, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, true, ck::index_t, A0DataType>; // clang-format on @@ -263,15 +264,18 @@ int main(int argc, char* argv[]) } Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; @@ -458,9 +462,10 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx11_supported() || ck::is_gfx12_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950, gfx11 and gfx12 only" << std::endl; } if(time_kernel) @@ -486,7 +491,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor d0_t_n( - HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0})); + HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}, Bypass{})); Tensor d1_e_n( HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 354957c0d1..cc42c4b815 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -292,17 +292,19 @@ int main(int argc, char* argv[]) } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + Scale_Block_K - 1) / Scale_Block_K}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k(HostTensorDescriptor( {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N}, - {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 3745e3d0af..29e758f9d4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -29,8 +29,9 @@ using F16 = ck::half_t; using F8 = ck::f8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F8; using B0DataType = I4; @@ -85,11 +86,11 @@ struct MulABScaleExpertWeight } }; -using CDEElementOp = MulABScaleExpertWeight; +using CDEElementOp = MulABScaleExpertWeight; +static constexpr int KPack = 32 / sizeof(B0DataType); void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) { - int KPack = 32; int NLane = NXdl; int KLane = 64 / NLane; @@ -135,7 +136,7 @@ static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); -static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); +static constexpr ck::index_t BK1 = KPack; static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; @@ -239,10 +240,10 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}, Bypass{})); Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); @@ -414,9 +415,10 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + ck::is_gfx11_supported() || ck::is_gfx12_supported())) { - std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + std::cout << "This kernel support gfx942, gfx950, gfx11 and gfx12 only" << std::endl; } if(time_kernel) diff --git a/example/66_complex_contraction_bilinear/common_instances.hpp b/example/66_complex_contraction_bilinear/common_instances.hpp index 480ca5a0af..ed1c1dc303 100644 --- a/example/66_complex_contraction_bilinear/common_instances.hpp +++ b/example/66_complex_contraction_bilinear/common_instances.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -37,7 +37,7 @@ using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 2, ComputeDataType>; // clang-format on template a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides); + Tensor a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides, DefaultLayout{}); + Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides, DefaultLayout{}); + Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides, DefaultLayout{}); - Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); // For Imaginary Part of Complex Tensor - Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides); + Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides, DefaultLayout{}); + Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides, DefaultLayout{}); + Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides, DefaultLayout{}); - Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); // Intermediate E tensor Definition - Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl; std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl; @@ -349,8 +350,10 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) if(do_verification) { // Real Part Verification - Tensor c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); - Tensor c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_re( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor c_ms_ns_host_result_re1( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); - Tensor c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_img( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor c_ms_ns_host_result_img1( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); auto ref_argument_img = ref_op.MakeArgument( a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op); diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index aaf0cb3891..69c0d6558f 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -269,10 +269,12 @@ int main(int argc, char* argv[]) Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -281,12 +283,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -480,7 +483,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -278,12 +280,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -477,7 +480,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -310,12 +313,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -506,7 +510,7 @@ int main(int argc, char* argv[]) { invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -286,7 +288,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 829bf9af24..5bb6454d2a 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -268,16 +268,18 @@ int main(int argc, char* argv[]) } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -286,7 +288,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index efbd0f0c03..333f8a3d52 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -303,16 +303,18 @@ int main(int argc, char* argv[]) expert_ids.savetxt("expert_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -321,7 +323,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/68_gemm_add/gemm_add_xdl_bf16.cpp b/example/68_gemm_add/gemm_add_xdl_bf16.cpp index 284e424c14..8861ad9cad 100644 --- a/example/68_gemm_add/gemm_add_xdl_bf16.cpp +++ b/example/68_gemm_add/gemm_add_xdl_bf16.cpp @@ -54,10 +54,10 @@ using DeviceOpInstance = 32, 8, 8, - 32, - 32, + 16, + 16, + 8, 4, - 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -75,7 +75,7 @@ using DeviceOpInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; #include "run_gemm_add_example_xdl.inc" diff --git a/example/68_gemm_add/gemm_add_xdl_fp16.cpp b/example/68_gemm_add/gemm_add_xdl_fp16.cpp index 4ba10e9d3b..0f21415311 100644 --- a/example/68_gemm_add/gemm_add_xdl_fp16.cpp +++ b/example/68_gemm_add/gemm_add_xdl_fp16.cpp @@ -54,10 +54,10 @@ using DeviceOpInstance = 32, 8, 8, - 32, - 32, + 16, + 16, + 8, 4, - 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -75,7 +75,7 @@ using DeviceOpInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; #include "run_gemm_add_example_xdl.inc" diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp index b5a84cd828..ac5586764c 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_bf16.cpp @@ -54,10 +54,10 @@ using DeviceOpInstance = 32, 8, 8, - 32, - 32, + 16, + 16, + 8, 4, - 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -75,7 +75,7 @@ using DeviceOpInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; #include "run_gemm_add_relu_example_xdl.inc" diff --git a/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp index 9e91641ba4..f9c963b4df 100644 --- a/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp +++ b/example/69_gemm_add_relu/gemm_add_relu_xdl_fp16.cpp @@ -54,10 +54,10 @@ using DeviceOpInstance = 32, 8, 8, - 32, - 32, + 16, + 16, + 8, 4, - 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -75,7 +75,7 @@ using DeviceOpInstance = 1, 1, S<1, 32, 1, 8>, - 8>; + 4>; #include "run_gemm_add_relu_example_xdl.inc" diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 7bd628edf2..940e7bc5e6 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -69,7 +69,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) list(REMOVE_ITEM FILE_NAME "${source}") endif() #Do not build any XDL examples if gfx9 targets are not on the list - if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") + if(NOT EX_TARGETS MATCHES "gfx9" AND NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() @@ -93,8 +93,8 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) message(DEBUG "removing bf8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 - if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") + # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 and gfx12 + if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95" AND NOT EX_TARGETS MATCHES "gfx12") if(source_name MATCHES "fp8" AND source_name MATCHES "(gemm_multiply_multiply|moe)") message(DEBUG "Skipping ${source} example for current target") list(REMOVE_ITEM FILE_NAME "${source}") @@ -109,14 +109,14 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endforeach() if(FILE_NAME) if(source_name_list MATCHES "_xdl" AND NOT source_name_list MATCHES "_pk_i4") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx10-3-generic) elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(source_name_list MATCHES "_mx") #only build mx example for gfx950 list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 + elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 gfx950 and rdna3/4 message(DEBUG "trimming targets for ${FILE_NAME}") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx10-3-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -192,7 +192,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) list(REMOVE_ITEM FILE_NAME "${source}") endif() #Do not build any XDL examples if gfx9 targets are not on the list - if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") + if(NOT EX_TARGETS MATCHES "gfx9" AND NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() @@ -206,7 +206,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) #only continue if there are some source files left on the list if(FILE_NAME) if(source_name_list MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx10-3-generic) elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) endif() diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index b1e2373657..b8ca26193d 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -26,7 +26,7 @@ endforeach() # "fwd" is a must-have api for the fmha_fwd example, add it if not specified if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS) - list(APPEND FMHA_FWD_ENABLE_APIS "fwd") + list(PREPEND FMHA_FWD_ENABLE_APIS "fwd") endif() file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS @@ -47,10 +47,19 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api bwd --receipt 3 - --optdim 32,64,128,256 + --optdim 32,64,96,128,256 # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... ) +# Reduce building time by disabling instances that are not currently used in the gtests +# TODO: Consider to use a special receipt for testing only, or even two receipts: a small subset of +# instances for quick CI runs and a larger subset for scheduled runs (the tests skip tests when +# there is no corresponding instance for parameters). +if(BUILD_TESTING) + # Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill + list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv) +endif() + # generate a list of kernels, but not actually emit files at config sta execute_process( COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS} @@ -160,6 +169,10 @@ if(CK_USE_OCP_FP8) list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() +# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) + target_compile_options(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS} INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS}) @@ -204,8 +217,20 @@ list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS -Wno-undefined-func-template --save-temps ) -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) +set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) +check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) +if(HAS_DISABLE_PACKED_FP32) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + -mllvm --amdgpu-disable-packed-fp32=1 + ) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS + -DCK_TILE_DISABLE_PACKED_FP32=1 + ) +endif() + +target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) +target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index cb6cd44f64..2b872cb9b5 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -36,6 +36,13 @@ args: total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) -s_k seqlen_k (including new key/value), -1 means equal to s (default:-1) + also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode) + -s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1) + Provide positive strides per-batch to simulate physical padding on Q + -s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1) + for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride + along seqlen, instead of packed, same as xformer kv_padding, + must be greater than or equal to s_k -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) @@ -76,11 +83,20 @@ args: -repeat number of iterations to benchmark the kernel (default:20) -json 0: No Json, 1: Dump Results in Json format (default:0) -jsonfile json file name to dump results (default:fmha_fwd.json) + -q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override +-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override ``` Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case +## Padding Examples +Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively. + +Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively. + ## support features Currently we are still in rapid development stage, so more features/optimizations will be coming soon. @@ -128,7 +144,16 @@ Note FA use bottom-right by default to express swa case, here we require you exp ### dropout TBD +### sequence padding and variable length support +We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths. + +**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment. + +**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste. + +Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios. + ## FP8 experimental support As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. -Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later. +Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later. diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 75978d6a7f..b6f491a0ee 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -1,16 +1,19 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { + "fp32" : "FmhaFwdFp32", "fp16" : "FmhaFwdFp16", "bf16" : "FmhaFwdBf16", "fp8" : "FmhaFwdFp8", "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16" + "fp8bf16": "FmhaFwdFp8Bf16", + "fp8fp32": "FmhaFwdFp8Fp32" } BWD_DTYPE_MAP = { + "fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16" } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 0d8f366d8a..e2f69fa49a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -601,6 +601,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 8f710050b1..059be0e490 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -50,16 +50,10 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits; using fmha_mask_{F_idx} = {F_mask}; using fmha_dropout_{F_idx} = {F_dropout}; @@ -94,19 +88,19 @@ using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, false, - {F_dvpad}>>; + ({F_dvpad} > 0)>>; using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dq_dk_dv_kernel_{F_idx} = ck_tile::FmhaBwdDQDKDVKernel float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ - const bool has_load_tr = ck_tile::is_load_tr_supported(); + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); float r = -1; {F_dispatch} return r; @@ -220,9 +214,9 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; r = fmha_bwd_>(s, a); return r; }} @@ -278,8 +272,8 @@ class FmhaBwdDQDKDVKernel: F_hdim : int # hdim F_dtype : str # data type F_tile : FmhaBwdDQDKDVTileSize - F_dpad : str # - F_dvpad : str # + F_dpad : Literal[0, 8 ,1] + F_dvpad : Literal[0, 8 ,1] F_bias : str # F_dbias : str # F_dropout : str # @@ -320,8 +314,8 @@ class FmhaBwdDQDKDVKernel: F_wm1 = self.F_tile.F_wm1, F_wn1 = self.F_tile.F_wn1, F_wk1 = self.F_tile.F_wk1, - F_dpad = BOOL_MAP[self.F_dpad], - F_dvpad = BOOL_MAP[self.F_dvpad], + F_dpad = self.F_dpad, + F_dvpad = self.F_dvpad, F_bias = BIAS_MAP[self.F_bias], F_dbias = BOOL_MAP[self.F_dbias], F_dropout = DROPOUT_MAP[self.F_dropout], @@ -337,8 +331,8 @@ class FmhaBwdDQDKDVKernel: def name(self) -> str: def pad_name() -> str: n = '' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' + if self.F_dpad : n += f'd{self.F_dpad}' + if self.F_dvpad : n += f'dv{self.F_dvpad}' if n != '' : n = 'p' + n return n pn = pad_name() @@ -376,18 +370,30 @@ class FmhaBwdDQDKDVKernel: # TODO: design a more practical way to do it # this is current supported tile size. def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: - if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': + if dtype == 'fp32' and tr_load == 'f': + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + ] + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': return [ FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), ] elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': return [ + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + + # FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32), + FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32), # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), ] @@ -621,8 +627,8 @@ class FmhaBwdApiTrait: dbias : str dropout : str spad1d : str # spad for 1d kernels (dot/convert) - dpad : str - dvpad : str + dpad : Literal[0, 1, 8] + dvpad : Literal[0, 1, 8] deterministic : str mask_impl : str tr_load : str @@ -651,13 +657,13 @@ class FmhaBwdApiTrait: @property def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' + if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0' + else: return f'a.hdim_q % {self.dpad} == 0' @property def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' + if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0' + else: return f'a.hdim_v % {self.dvpad} == 0' @property def extra_cond(self) -> str: @@ -677,8 +683,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dvpad = 't' if self.dvpad else 'f' return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, - F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) @property def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: @@ -693,8 +700,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dpad = 't' if self.dpad else 'f' return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad, + F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) @@ -720,7 +728,7 @@ class FmhaBwdApiPool: F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], - F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad, F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond, F_convert_dq_bn0=trait.convert_dq_bn0) @@ -793,7 +801,10 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) - for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): + dpad_options = itertools.product(*([[0, 8, 1]] * 2)) + tf = ["t", "f"] + for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product( + tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf): assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): @@ -804,8 +815,14 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if ("wg32" in dropout): continue - if tr_load == "t" and (dpad == "t" or dvpad == "t"): - continue # tr_load cannot work with dpad or dvpad + if tr_load == "t": + # tr_load can only work with 8 pad + if dpad != dvpad or dpad == 1: + continue + else: # tr_load == "f" + # do not generate instance with only 1 of dpad/dvpad being 8 + if dpad != dvpad and dpad == 8: + continue if optdim_list != [-1]: if hdim not in optdim_list: continue @@ -861,6 +878,30 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= dpad == dvpad + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [64, 128] + cond &= dpad == dvpad + cond &= mode == 'batch' + cond &= bias == 'no' + cond &= dropout == 'no' + cond &= mask == 's_no' + cond &= deterministic == "f" + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + gen_dot_do_o[t.dot_do_o_kernel] = True gen_dq_dk_dv[t.dq_dk_dv_kernel] = True if not t.convert_dq_kernel.disabled: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 14d3b84c74..6fc808c3ef 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -25,6 +25,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = { 32 : 32, + 48 : 48, 64 : 64, 96 : 128, 128: 128, @@ -163,8 +164,8 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - - const bool has_load_tr = ck_tile::is_load_tr_supported(); + + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; @@ -263,7 +264,7 @@ class FmhaFwdApiTrait: else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' elif self.pipeline_tag in ['qr', 'qs','qr_wholek_prefetch']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag == 'qr_async_trload': if self.skpad == 't' : return 'true' else: return 'true' @@ -351,7 +352,7 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' - + if self.F_trload == 't' : n += '_trload' else: n += '_ntrload' @@ -378,7 +379,7 @@ class FmhaFwdApiPool: "t": "has_load_tr", "f": "true" } - + per_tr_load =str() for tr_load in ["t", "f"]: per_dtypes=str() @@ -386,6 +387,7 @@ class FmhaFwdApiPool: per_hdim_case=str() for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] + max_bm0 = max((t.bm0 for t in traits), default=0) inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' @@ -393,7 +395,7 @@ class FmhaFwdApiPool: F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, @@ -534,7 +536,20 @@ class KernelComponentFactory: # this is current supported tile size per hdim @staticmethod def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': + if dtype == 'fp32': + return { + # bm0, bn0, bk0, bn1, bk1, + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } + elif dtype == 'fp16' or dtype == 'bf16': return { (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -551,12 +566,16 @@ class KernelComponentFactory: (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == 'fp8' or dtype == 'bf8': + elif dtype == 'fp8' or dtype == 'fp8bf16': return { (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } + elif dtype == 'fp8fp32': + return { + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } else: return None @@ -568,9 +587,15 @@ class KernelComponentFactory: # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? - squant = 't' if dtype == 'fp8' else 'f' pipelines = [] - if dtype in ['fp16', 'bf16']: + if dtype in ['fp32']: + squant = 'f' + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + elif dtype in ['fp16', 'bf16']: + squant = 'f' for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) @@ -593,11 +618,12 @@ class KernelComponentFactory: pipelines.append(FmhaFwdPipeline('qr_wholek_prefetch', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: + elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: + for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + elif dtype in ['fp8fp16', 'bf8']: # TODO None else: @@ -625,6 +651,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, next_tile in zip(tiles, tiles[1:]): + assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0' for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): if mode == "group": if pipeline.tag != 'qr_wholek_prefetch' and (pipeline.F_spad != 't' or pipeline.F_skpad != 't'): @@ -682,27 +710,61 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'batch' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue + elif receipt == 888: + cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32'] + cond &= pipeline.F_vlayout == 'row' + cond &= hdim == 128 + if not cond: + continue + + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [48, 128] + cond &= mode == 'batch' + cond &= pipeline.F_bias == 'no' + cond &= pipeline.F_lse == 'f' + cond &= pipeline.F_dropout == 'f' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + cond &= pipeline.F_mask == 's_no' + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 0ebeaddf9c..38491b56c4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -184,6 +184,9 @@ class FmhaFwdAppendKVApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) @dataclass @@ -341,6 +344,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op cond &= pipeline.F_vlayout == 'row' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 3b48b3d005..281357ef1e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -645,7 +645,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -769,6 +768,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) @@ -835,6 +841,13 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + gen.append(k) return gen diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 7b93e9654c..3624b7b387 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -465,14 +465,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]): pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) elif dtype in ['fp8', 'bf8']: - # TODO - None + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -560,6 +560,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if not cond: continue + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp index e0e1fba668..73b3c1e619 100644 --- a/example/ck_tile/01_fmha/example_fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -43,7 +43,7 @@ auto create_args(int argc, char* argv[]) "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("dbias", "0", "output bias gradient or not") - .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("prec", "fp16", "data type. fp32/fp16/bf16") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -159,7 +159,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == bwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == bwd_result::success ? 0 : -2; } diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index c3bbb7a558..c27a5ce1ae 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -33,6 +33,10 @@ auto create_args(int argc, char* argv[]) "0", "seqlen_k for new key/value, 0 means not to use this at all; " "-1 to choose s_knew in [1, s] randomly.") + .insert("s_qpad", + "-1", + "seqlen_q stride between 2 batches (group-mode optional).\n" + "Provide positive strides per-batch to simulate physical padding on Q.") .insert("s_kpad", "-1", "seqlen_k stride between 2 batches, currently used in group-mode only\n" @@ -44,21 +48,15 @@ auto create_args(int argc, char* argv[]) .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" - "note when squant=1, this value will be modified by range_q/k") + "note when squant=1, this value will be modified") .insert("logits_soft_cap", "0", "attention logits soft capping value.") - .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") - .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") - .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") - .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") - .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") .insert("squant", "auto", "if using static quantization fusion or not. auto: fp8 will default use squant, " "other will not\n" "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " "P and O.\n" - "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " - "range_p, range_o") + "calculate scale_s, scale_p, scale_o auto") .insert("iperm", "1", "permute input\n" @@ -69,7 +67,7 @@ auto create_args(int argc, char* argv[]) "n or 0, no bias\n" "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -89,7 +87,7 @@ auto create_args(int argc, char* argv[]) "uf", "init method:\n ui or 0 - uniform random int\n ni - normalized random int" "\n uf or 1 - uniform random float\n nf - normalized random float" - "\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization") + "\n tf or 2 - trig float\n") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " @@ -113,7 +111,15 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); + .insert("jsonfile", "fmha_fwd.json", "json file name to dump results") + .insert("q_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override.") + .insert("kv_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -133,6 +139,9 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); + auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); + auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens"); + auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens"); ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); bool i_perm = arg_parser.get_bool("iperm"); bool o_perm = arg_parser.get_bool("operm"); @@ -148,11 +157,6 @@ auto run(const ck_tile::ArgParser& arg_parser) uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); bool drop_prefs = arg_parser.get_bool("drop_prefs"); std::string mask_str = arg_parser.get_str("mask"); - float range_q = arg_parser.get_float("range_q"); - float range_k = arg_parser.get_float("range_k"); - float range_v = arg_parser.get_float("range_v"); - float range_p = arg_parser.get_float("range_p"); - float range_o = arg_parser.get_float("range_o"); bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); std::string init_method = arg_parser.get_str("init"); @@ -185,7 +189,10 @@ auto run(const ck_tile::ArgParser& arg_parser) hdim_q, hdim_v, seqlen_knew, + seqlen_qpads, seqlen_kpads, + q_eff_lens_per_batch, + kv_eff_lens_per_batch, rotary_dim, i_perm, o_perm, @@ -201,11 +208,6 @@ auto run(const ck_tile::ArgParser& arg_parser) drop_offset, drop_prefs, mask_str, - range_q, - range_k, - range_v, - range_p, - range_o, squant, is_rotary_interleaved, num_splits, @@ -225,7 +227,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == fwd_result::success ? 0 : -2; } @@ -237,6 +243,14 @@ int main(int argc, char* argv[]) { return run(arg_parser) == fwd_result::success ? 0 : -2; } + else if(data_type == "fp8bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp8fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } std::cerr << "Unsupported precision: " << data_type << std::endl; return -1; } diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index d2428e5152..7ddb65a2db 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -45,25 +45,23 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair(hdim)); - mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k); + + const auto is_causal = args.get_bool("causal"); + if(is_causal) + { + mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); + } + else + { + mask = mask_info::decode("0", seqlen_q, seqlen_k); + } input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; + q_eff_lens = args.get_int_vec("q_eff_lens"); + kv_eff_lens = args.get_int_vec("kv_eff_lens"); } std::vector get_query_shape() const @@ -174,6 +183,8 @@ struct Problem mask_info mask; TensorLayout input_layout; TensorLayout output_layout; + std::vector q_eff_lens; + std::vector kv_eff_lens; }; struct RunConfig @@ -328,8 +339,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) q_buf.ToDevice(q.data()); k_buf.ToDevice(k.data()); v_buf.ToDevice(v.data()); + // Ensure output buffer is zero-initialized so padded regions compare cleanly + o_buf.SetZero(); - ck_tile::fmha_fwd_v3_args args; + ck_tile::fmha_fwd_v3_args args{}; args.data_type = problem.data_type; args.batch = problem.batch; @@ -382,6 +395,60 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) : problem.seqlen_q * problem.hdim; args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; + // Optional cumulative seqlen overrides (exclude PAD) + const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; + const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; + + auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { + std::vector eff; + if(!opt_vec.empty() && opt_vec[0] != -1) + { + eff.assign(opt_vec.begin(), opt_vec.end()); + if(eff.size() < static_cast(problem.batch)) + { + eff.resize(problem.batch, eff.back()); + } + } + else + { + eff.assign(problem.batch, fallback); + } + return eff; + }; + + const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); + const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); + + // Calculate cumulative sums for kernel arguments if varlen is used + std::vector cuq_cum, cukv_cum; + auto calculate_cumulative = [&](const std::vector& per_batch_vec, + std::vector& cum_vec) { + cum_vec.resize(per_batch_vec.size() + 1); + cum_vec[0] = 0; + for(std::size_t i = 0; i < per_batch_vec.size(); ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + }; + + if(has_varlen_q) + { + calculate_cumulative(eff_q_vec, cuq_cum); + } + if(has_varlen_k) + { + calculate_cumulative(eff_kv_vec, cukv_cum); + } + + ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); + ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); + cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); + cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); + args.cu_seqlen_q_ptr = + !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) + : nullptr; + args.cu_seqlen_kv_ptr = + !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) + : nullptr; + ck_tile::stream_config stream_config{nullptr, true, /*log_level=*/0, @@ -444,15 +511,72 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) o_ref = o_ref.transpose({0, 2, 1, 3}); } - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); + // If variable lengths are provided, compute per-batch references + // with the effective lengths; else compute a single full reference. + if(has_varlen_q || has_varlen_k) + { + // Variable-length aware verification: zero-fill padded region and only compute valid part. + o_ref.SetZero(); + + for(int b = 0; b < problem.batch; ++b) + { + const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; + + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); + v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(b, s, h, d) = o_b(0, s, h, d); + } + } + } + } + } + else + { + // No varlen override: compute the full reference once + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + } ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index f1f8eee5e4..6cd1cd94fa 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -15,6 +15,10 @@ #include #include +struct FmhaBwdFp32 +{ +}; + struct FmhaBwdFp16 { }; @@ -26,6 +30,26 @@ struct FmhaBwdBf16 template struct FmhaBwdTypeConfig; +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using GemmDataType = float; + using BiasDataType = float; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = float; + using OGradDataType = float; + using QGradDataType = float; + using KGradDataType = float; + using VGradDataType = float; + using BiasGradDataType = float; +}; + template <> struct FmhaBwdTypeConfig { @@ -368,8 +392,8 @@ template +auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) +{ + double rtol = 1e-4; + double atol = 1e-4; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) { @@ -77,7 +85,9 @@ bwd_result fmha_bwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; @@ -763,15 +773,21 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_acc_host); dq_buf.ToDevice(dq_host.data()); dk_buf.ToDevice(dk_host.data()); dv_buf.ToDevice(dv_host.data()); + dq_acc_buf.ToDevice(dq_acc_host.data()); o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); - dq_buf.SetZero(); dbias_buf.SetZero(); - dq_acc_buf.SetZero(); + + // non-deterministic kernels use atomic add to write dq + // Some block may be skipped with causal mask and dq are not set to zeros + // In these cases thus we need to zero out it first + if(!deterministic || mask.type != mask_enum::no_mask) + dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; fmha_bwd(fmha_traits, fmha_args, stream_config_v); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index df1e9e5699..761def6d6a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,6 +17,10 @@ #include #include +struct FmhaFwdFp32 +{ +}; + struct FmhaFwdFp16 { }; @@ -41,9 +45,29 @@ struct FmhaFwdFp8Bf16 { }; +struct FmhaFwdFp8Fp32 +{ +}; + template struct FmhaFwdTypeConfig; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = float; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + template <> struct FmhaFwdTypeConfig { @@ -108,6 +132,38 @@ struct FmhaFwdTypeConfig using ODataType = ck_tile::bf8_t; }; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + struct FmhaMasks { using NoMask = ck_tile::GenericAttentionMask; @@ -126,11 +182,20 @@ struct fmha_fwd_args void* lse_ptr; void* o_ptr; + // Optional cumulative sequence length arrays + // Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD) + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + // Group mode: seqstart_padded_* provide physical starts including PAD (optional) + const void* seqstart_padded_q_ptr = nullptr; // [batch+1] + const void* seqstart_padded_k_ptr = nullptr; // [batch+1] + ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -518,7 +583,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.min_seqlen_q, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.seqstart_padded_q_ptr, + args.seqstart_padded_k_ptr); } else { // create batch mode kernel arguments @@ -564,7 +631,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 397245ab32..0703af71e3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -41,6 +41,14 @@ auto get_elimit(std::string /*init_method*/) return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(std::string /*init_method*/) { @@ -50,20 +58,30 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string init_method) +auto get_elimit(std::string /*init_method*/) { - if(init_method == "ui" || init_method == "ni") - { - unsigned max_rounding_point_distance = 0; - double atol = 2e-3; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } - else - { - unsigned max_rounding_point_distance = 1; - double atol = 0.0625; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } + using TypeConfig = FmhaFwdTypeConfig; + using ODataType = typename TypeConfig::ODataType; + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + double rtol = 0; + double atol = 16 * (o_dtype_max > 240 ? 2 : 1); + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); } int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) @@ -141,7 +159,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t seqlen_knew, + std::vector seqlen_qpads, std::vector seqlen_kpads, + std::vector q_eff_lens_per_batch, + std::vector kv_eff_lens_per_batch, ck_tile::index_t rotary_dim, bool i_perm, bool o_perm, @@ -157,11 +178,6 @@ fwd_result fmha_fwd_run(mode_enum mode, uint64_t drop_offset, bool drop_prefs, std::string mask_str, - float range_q, - float range_k, - float range_v, - float range_p, - float range_o, bool squant, bool is_rotary_interleaved, ck_tile::index_t num_splits, @@ -172,7 +188,9 @@ fwd_result fmha_fwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; @@ -180,6 +198,10 @@ fwd_result fmha_fwd_run(mode_enum mode, return "fp8"; else if constexpr(std::is_same_v) return "bf8"; + else if constexpr(std::is_same_v) + return "fp8bf16"; + else if constexpr(std::is_same_v) + return "fp8fp32"; else static_assert(false); }(); @@ -290,6 +312,24 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); + // Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv) + const bool has_group_padding = + (mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) || + (mode == mode_enum::group && (seqlen_kpads[0] >= 0)); + const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() || + !kv_eff_lens_per_batch.empty())); + const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); + const bool using_pagedkv = (0 < page_block_size); + const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; + if((using_appendkv || using_pagedkv || using_splitkv) && + (has_group_padding || has_batch_efflens)) + { + std::cerr << "Padding (physical or effective lengths) is not supported with " + "appendkv/splitkv/pagedkv pipelines" + << std::endl; + return fwd_result::invalid_args; + } + std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = generate_missing_seqlens(mode, batch, @@ -353,6 +393,44 @@ fwd_result fmha_fwd_run(mode_enum mode, const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + // Optional padded Q seqstarts (group-mode only) + std::vector seqstart_q_with_padding_host; + if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) + { + if(seqlen_qpads.size() < static_cast(batch)) + { + seqlen_qpads.resize(batch, seqlen_qpads.back()); + } + if(seqlen_qpads.size() == static_cast(batch)) + { + seqstart_q_with_padding_host = to_seqstarts( + ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); + } + } + + // Optional batch-mode cumulative seqlen overrides + std::vector cuq_cum, cukv_cum; + if(mode == mode_enum::batch) + { + auto calculate_cumulative = [&](std::vector& per_batch_vec, + std::vector& cum_vec) { + if(!per_batch_vec.empty() && per_batch_vec[0] != -1) + { + if(per_batch_vec.size() < static_cast(batch)) + { + per_batch_vec.resize(batch, per_batch_vec.back()); + } + cum_vec.resize(batch + 1); + cum_vec[0] = 0; + for(int i = 0; i < batch; ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + } + }; + + calculate_cumulative(q_eff_lens_per_batch, cuq_cum); + calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); + } + using TypeConfig = FmhaFwdTypeConfig; using QDataType = typename TypeConfig::QDataType; @@ -367,22 +445,6 @@ fwd_result fmha_fwd_run(mode_enum mode, using OaccDataType = typename TypeConfig::OaccDataType; using ODataType = typename TypeConfig::ODataType; - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float p_dtype_max = v_dtype_max; // assume p and v is the same type - float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - - float scale_p = 1.f; - float scale_o = 1.f; - - if(squant) - { - scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max); - scale_p = p_dtype_max / range_p; - scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max); - } - // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = @@ -452,8 +514,15 @@ fwd_result fmha_fwd_run(mode_enum mode, // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - const ck_tile::index_t shape_seqlen_q = + // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen + const ck_tile::index_t shape_seqlen_q_lse = (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); + // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch + ? seqlen_qs[0] + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() + : seqstart_q_with_padding_host.back())); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() @@ -511,7 +580,7 @@ fwd_result fmha_fwd_run(mode_enum mode, // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q} + lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -528,7 +597,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx ? std::array{batch} : std::array{1}); - + float max_o = 5.0; if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); @@ -576,32 +645,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillTrigValue{}(vnew_host); ck_tile::FillTrigValue{}(bias_host); } - else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3") - { - // suitable for fp8 quantization - if(!squant) - { - std::cerr << "init method " << init_method << " can not be used without quantization" - << std::endl; - return fwd_result::invalid_args; - } - ck_tile::FillUniformDistribution{0.f, q_dtype_max, next_seed()}(q_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(k_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(knew_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(v_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(vnew_host); - - // bias_fp8 = qscale_bias * bias_fp32 - float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); - // Assume bias is in [0.f, 1.f] in original fp32 - ck_tile::FillUniformDistribution{0.f, qscale_bias, next_seed()}(bias_host); - } - else - { - std::cerr << "Unknown value for init argument: " << init_method << std::endl; - return fwd_result::invalid_args; - } - if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -625,8 +668,8 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); @@ -635,6 +678,16 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() + ? 0 + : seqstart_q_with_padding_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k_padded_buf( + seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 + : cuq_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem cu_seqlen_kv_buf( + cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) @@ -650,15 +703,90 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); + float scale_p = 1.f; + float scale_o = 1.f; + if(squant) + { + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float p_dtype_max = v_dtype_max; // assume p and v is the same type + // Q tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = q_dtype_max / max_value; + + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // K tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + float scale = k_dtype_max / max_value; + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // V tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = k_dtype_max / max_value; + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + + scale_o = (1.0 / p_dtype_max) / scale; + } + + scale_p = p_dtype_max; + + if constexpr(std::is_same_v) + { + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o = scale_o * o_dtype_max / max_o; + } + } + q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); - knew_buf.ToDevice(knew_host.data()); v_buf.ToDevice(v_host.data()); + knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() - : seqstart_k_with_padding_host.data()); + // Keep logical starts in seqstart_k; pass padded K via separate pointer + seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_q_padded_buf.ToDevice( + seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); + seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr + : seqstart_k_with_padding_host.data()); + cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); + cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); @@ -711,6 +839,54 @@ fwd_result fmha_fwd_run(mode_enum mode, std::cout << ", cache_batch_idx:" << use_cache_batch_idx; } #endif + // Padding / effective length diagnostic logging + auto print_vec = [&](const char* label, const std::vector& v) { + if(v.empty()) + return; + std::cout << ", " << label << ":["; + for(std::size_t i = 0; i < v.size(); ++i) + { + if(i) + std::cout << ","; + std::cout << v[i]; + } + std::cout << "]"; + }; + + if(has_group_padding) + { + bool has_qpad = !seqstart_q_with_padding_host.empty(); + bool has_kpad = (seqlen_kpads[0] >= 0); + if(has_qpad) + { + print_vec("q_logical", seqlen_qs); + print_vec("q_padded", seqlen_qpads); + } + if(has_kpad) + { + print_vec("k_logical", seqlen_ks); + print_vec("k_padded", seqlen_kpads); + } + } + else if(has_batch_efflens) + { + // derive effective lengths from cumulative arrays if present + if(!cuq_cum.empty()) + { + std::vector eff_q(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_q[b_i] = static_cast(cuq_cum[b_i + 1] - cuq_cum[b_i]); + print_vec("q_eff", eff_q); + } + if(!cukv_cum.empty()) + { + std::vector eff_kv(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_kv[b_i] = static_cast(cukv_cum[b_i + 1] - cukv_cum[b_i]); + print_vec("kv_eff", eff_kv); + } + } + std::cout << std::flush; const auto init_traits = [&](auto& traits) { @@ -794,8 +970,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -810,8 +986,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); @@ -925,6 +1101,29 @@ fwd_result fmha_fwd_run(mode_enum mode, { args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } + + // Group-mode: optional physical padded starts for Q/K + if(mode == mode_enum::group) + { + args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() + ? nullptr + : seqstart_q_padded_buf.GetDeviceBuffer()); + args.seqstart_padded_k_ptr = + (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); + } + + // Batch-mode: optional cumulative effective seqlen overrides + if(mode == mode_enum::batch) + { + args.cu_seqlen_q_ptr = cuq_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_q_buf.GetDeviceBuffer()); + args.cu_seqlen_kv_ptr = cukv_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_kv_buf.GetDeviceBuffer()); + } } else if constexpr(std::is_same_v>) { @@ -964,7 +1163,7 @@ fwd_result fmha_fwd_run(mode_enum mode, } }; - const float appendkv_ave_time = [&] { + auto run_appendkv = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_APPENDKV_API if(need_append_kvcache) { @@ -974,18 +1173,19 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_appendkv_args fwd_appendkv_args; init_args(fwd_appendkv_args); - return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config); + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc); } #endif return 0.0f; - }(); + }; + const float appendkv_ave_time = run_appendkv(stream_config); if(appendkv_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; return fwd_result::no_instance; } - const float fwd_ave_time = [&] { + auto run_fwd = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_PAGEDKV_API if(1 == num_splits && use_kvcache) { @@ -995,8 +1195,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_pagedkv_args fmha_pagedkv_args; init_args(fmha_pagedkv_args); - const float ave_time = - fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); + const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc); #if CK_TILE_FMHA_FWD_SPLITKV_API // If there is no instance for these args, fallback to fmha_fwd_splitkv if(ave_time >= 0.0f) @@ -1015,7 +1214,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_splitkv_args fmha_splitkv_args; init_args(fmha_splitkv_args); - return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc); } #endif // CK_TILE_FMHA_FWD_SPLITKV_API fmha_fwd_traits fmha_traits; @@ -1024,8 +1223,9 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_args fmha_args; init_args(fmha_args); - return fmha_fwd(fmha_traits, fmha_args, stream_config); - }(); + return fmha_fwd(fmha_traits, fmha_args, sc); + }; + const float fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; @@ -1099,11 +1299,24 @@ fwd_result fmha_fwd_run(mode_enum mode, } else { +#if CK_TILE_FMHA_FWD_APPENDKV_API + // When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times + // when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. + if(0 < rotary_dim && stream_config.time_kernel_) + { + const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0}; + q_buf.ToDevice(q_host.data()); + run_appendkv(stream_config2); + run_fwd(stream_config2); + } +#endif o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); - constexpr bool supports_squant = std::is_same_v; + constexpr bool supports_squant = std::is_same_v || + std::is_same_v || + std::is_same_v; auto p_compute_element_func = [&]() { if constexpr(supports_squant) @@ -1113,9 +1326,11 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); auto oacc_element_func = [&]() { - if constexpr(supports_squant) + if constexpr(std::is_same_v && supports_squant) return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{scale_o}); + else if constexpr(supports_squant) + return ck_tile::scales{scale_o}; else return ck_tile::identity{}; }(); @@ -1127,15 +1342,29 @@ fwd_result fmha_fwd_run(mode_enum mode, for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + if(mode == mode_enum::batch) + { + if(!cuq_cum.empty()) + { + real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; + } + if(!cukv_cum.empty()) + { + real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; + } + } // adjust matrix index according to the mode const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t cache_b_idx = (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); const ck_tile::index_t query_offset = - (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + (mode == mode_enum::batch + ? 0 + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] + : seqstart_q_with_padding_host[wb])); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 @@ -1498,8 +1727,10 @@ fwd_result fmha_fwd_run(mode_enum mode, if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + const ck_tile::index_t query_offset_lse = + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); }); cur_pass = ck_tile::check_err(lse_host_result, diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 5361d27f0f..4bd1d1a367 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -34,7 +34,8 @@ struct fmha_fwd_v3_args index_t window_size_left; index_t window_size_right; - index_t mask_type; + index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and + // window_size_right == 0). const void* q_ptr; index_t stride_q; @@ -55,6 +56,11 @@ struct fmha_fwd_v3_args index_t stride_o; index_t nhead_stride_o; index_t batch_stride_o; + + // Optional batch-mode cumulative seqlen overrides (exclude PAD) + // If provided, they override per-batch effective lengths to skip tail padding. + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index d6e4ac4c60..194675f962 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -18,6 +18,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "fmha_fwd_v3.hpp" +#include "mask.hpp" #define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ template <> \ @@ -79,7 +80,7 @@ struct fmha_fwd_v3_kernel_traits -1 // kBlockPerCu >; - using fmha_mask = SimplifiedGenericAttentionMask; + using fmha_mask = GenericAttentionMask; using fmha_pipeline_problem = BlockFmhaFwdV3PipelineProblem::qkvp_dtype, @@ -112,6 +113,22 @@ struct fmha_fwd_v3_kernel_traits template float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) { + /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly + /// maximizes the kernel's performance. + int remap_opt = 2; + if(args.mask_type != static_cast(mask_enum::no_mask) && + ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) + { + if(65536 <= args.seqlen_q) + { + remap_opt = 0; + } + else + { + remap_opt = 1; + } + } + auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, @@ -140,7 +157,10 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 88c16cceb6..31ad800039 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,3 +18,36 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done + +#Padding Benchmarks: batch mode (baseline vs low/med/high pad) +prec="fp16" +base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no pad) +$EXE $base_batch_args + +# low pad (≈90–95% effective) +$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + +# Padding Benchmarks: group mode (baseline vs low/med/high physical pad) +seqlens_q="1024,768,512,256" +seqlens_k="1024,768,512,256" +base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no physical pad) +$EXE $base_group_args + +# low physical pad +$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + +# medium physical pad +$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + +# high physical pad +$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index 9c500edf9d..a3f7d68eb3 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -8,24 +8,35 @@ for prec in "fp16" "bf16" ; do for hdim in 128 ; do for perm in 0 ; do -if [ $causal -eq 0 ]; then - mask=0 -else - mask=b:-1,0 -fi - -$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID done done done done + +# Padding benchmark comparisons for v3 (batch mode only) +# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ==== +prec="fp16" +base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID" + +# baseline (no pad) +$EXE $base_v3_args + +# low pad (≈90–95% effective) +$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx90a.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx942.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index e7babd2744..5c2a5a4b3d 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -34,15 +34,15 @@ function print_log_header(){ } #run verification tests -example/ck_tile/01_fmha/script/smoke_test_fwd.sh -example/ck_tile/01_fmha/script/smoke_test_bwd.sh +time example/ck_tile/01_fmha/script/smoke_test_fwd.sh +time example/ck_tile/01_fmha/script/smoke_test_bwd.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" print_log_header $fmha_fwd_log $env_type $branch $host_name -example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log +time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log" print_log_header $fmha_bwd_log $env_type $branch $host_name -example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log +time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index d123f842a2..cd51dde2d4 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -2,14 +2,46 @@ # TODO: run this script from CK root or build directory set -euo pipefail -EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)" +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_bwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" KNAME=1 +GPU_arch=${GPU_arch:-""} +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi export CK_WARMUP=0 export CK_REPEAT=1 +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"} + COMMON_ARGS='-v=1' + +run_exe() { + set +ex + $EXE $@ + local ret=$? + if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + +test_h_s_mask() { + run_exe -b=1 -h=4 -h_k=2 -s=259 $@ + run_exe -b=2 -h=2 -s=516 -s_k=253 $@ + run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 $@ + run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 $@ + run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 $@ + run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 $@ +} + set -x +# main tests for prec in "fp16" "bf16" ; do for perm in 0 1 ; do for hdim in 32 64 128 256 ; do @@ -18,20 +50,41 @@ for bias in "n" "a" ; do for dbias in 0 ; do for p_drop in 0.0 0.2 ; do for deterministic in 0 ; do +test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +done +done +done +done +done +done +done +done -$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS - -done -done -done -done -done -done -done +# additional cases +for hdim in 40 48 72 96 ; do +test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS done set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $(($new_fails_count != 0)) diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 3913a0d5c2..fca6b8d0cd 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -2,12 +2,23 @@ # TODO: run this script from CK root or build directory set -euo pipefail -EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_fwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" KNAME=1 +GPU_arch=$GPU_arch +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi export CK_WARMUP=0 export CK_REPEAT=1 +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"} + COMMON_ARGS='-v=1 -warmup=0 -repeat=1' # mode=0 # export HIP_VISIBLE_DEVICES=4 @@ -30,6 +41,16 @@ while getopts ":sa" opt; do esac done +run_exe() { + set +ex + $EXE $@ + local ret=$? + if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + run_fp16_bf16_tests() { local NUM_SPLITS="1" local PAGE_BLOCK_SIZE="0" @@ -52,16 +73,16 @@ run_fp16_bf16_tests() { for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done @@ -73,7 +94,29 @@ run_fp8_tests() { for b in 1 2 ; do for hdim in 64 128 256 ; do - $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8bf16_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8fp32_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS done ; done ; done ; done } @@ -88,19 +131,151 @@ run_fp16_appendkv_tests() { for page_block_size in 0 128 ; do for cache_batch_idx in 0 1 ; do - $EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS + run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done } +run_padding_smoke_tests() { + # Padding-only smoke tests for batch/group mode using COMMON_ARGS + local prec="fp16" + + # Batch mode: padding via effective lengths (exclude PAD) + # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches + local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low pad (≈90–95% effective) + $EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + # medium pad (≈60–75% effective) + $EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + # high pad (≈30–40% effective) + $EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + + # Group mode: padding via physical stride along seqlen + local seqlens_q="1024,768,512,256" + local seqlens_k="1024,768,512,256" + local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low physical pad + $EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + # medium physical pad + $EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + # high physical pad + $EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 +} + +run_padding_basic_boundary_tests() { + # Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh) + local prec + local perm + + # Group mode: Q&K padded with per-batch different strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \ + -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # slightly larger, uneven padding strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \ + -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # only K padded; Q unpadded + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \ + -s=55 -s_k=256 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # use cu_seqlen overrides to skip tail PAD + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \ + -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \ + -q_eff_lens=55,60 -kv_eff_lens=200,256 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # no padding (equal), mixed Q/KV, all len=1 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # highly variable logical lengths + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \ + -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \ + -s=256,129 -s_k=256,129 -s_kpad=256 \ + -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \ + -kname=$KNAME $COMMON_ARGS + done +} + set -x run_fp16_bf16_tests +run_padding_smoke_tests +run_padding_basic_boundary_tests run_fp8_tests +run_fp8bf16_tests +run_fp8fp32_tests if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests fi set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $(($new_fails_count != 0)) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index b7512b2999..5f589db8d0 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,6 +75,39 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps; + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); + } + }(); + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp index 0455e8e34d..b4e0df711b 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp @@ -6,6 +6,7 @@ #include "run_gemm_example_common.hpp" #include "gemm_splitk_two_stage_invoker.hpp" +template