diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1d0f7df3c6..af36f492ba 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @shumway @vidyasagar-amd +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd # Documentation files docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD *.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD diff --git a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py index 557afe2d84..cc66fdbfe8 100644 --- a/.github/scripts/therock_configure_ci.py +++ b/.github/scripts/therock_configure_ci.py @@ -42,6 +42,24 @@ def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]: file=sys.stderr, ) return None + +GITHUB_WORKFLOWS_CI_PATTERNS = [ + "therock*", +] + +def is_path_workflow_file_related_to_ci(path: str) -> bool: + return any( + fnmatch.fnmatch(path, ".github/workflows/" + pattern) + for pattern in GITHUB_WORKFLOWS_CI_PATTERNS + ) or any( + fnmatch.fnmatch(path, ".github/scripts/" + pattern) + for pattern in GITHUB_WORKFLOWS_CI_PATTERNS + ) + +def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool: + if paths is None: + return False + return any(is_path_workflow_file_related_to_ci(p) for p in paths) # Paths matching any of these patterns are considered to have no influence over # build or test workflows so any related jobs can be skipped if all paths @@ -82,12 +100,16 @@ def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool: ) other_paths = paths_set - github_workflows_paths + related_to_ci = check_for_workflow_file_related_to_ci(github_workflows_paths) contains_other_non_skippable_files = check_for_non_skippable_path(other_paths) print("should_ci_run_given_modified_paths findings:") print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}") - if contains_other_non_skippable_files: + if related_to_ci: + print("Enabling build jobs since a related workflow file was modified") + return True + elif contains_other_non_skippable_files: print("Enabling TheRock CI jobs since a non-skippable path was modified") return True else: diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 7db124d2a1..695fb1d913 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -27,30 +27,35 @@ jobs: TEATIME_FORCE_INTERACTIVE: 0 AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini steps: + - name: "Checking out repository for rocm-libraries" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/rocm-libraries" + - name: Checkout composable_kernel repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: "composable_kernel" - name: Checkout TheRock repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57 + ref: 409f43ad9d564454bb1b23f8c8aa15d6b9d25200 path: "TheRock" - name: Runner Health Settings run: | - df -h - cmake --version - echo "Installed Python versions:" - ls -d /opt/python - echo "python: $(which python), python3: $(which python3)" - echo "Git version: $(git --version)" - git config --global --add safe.directory $PWD - git config fetch.parallel 10 + ./TheRock/build_tools/health_status.py - name: Fetch sources run: | - ./TheRock/build_tools/fetch_sources.py --jobs 12 + ./TheRock/build_tools/fetch_sources.py --jobs 12 --no-include-rocm-libraries --no-include-ml-frameworks + + - name: Patch rocm-libraries + run: | + git config --global --add safe.directory '*' + git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps run: | @@ -92,32 +97,14 @@ jobs: aws-region: us-east-2 role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external - - name: Create Logs index Files and upload logs + - name: Post Build Upload if: always() run: | - python3 TheRock/build_tools/github_actions/create_log_index.py \ - --build-dir=TheRock/build \ - --amdgpu-family=${{ env.AMDGPU_FAMILIES }} - - python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \ - --build-dir=TheRock/build \ - --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} - - - name: Upload artifacts - run: | - python TheRock/build_tools/github_actions/upload_build_artifacts.py \ + python3 TheRock/build_tools/github_actions/post_build_upload.py \ --run-id ${{ github.run_id }} \ --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ - --build-dir TheRock/build - - - name: Add Links to Job Summary - if: always() - run: | - python TheRock/build_tools/github_actions/upload_build_summary.py \ - --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ - --build-dir TheRock/build + --build-dir TheRock/build \ + --upload therock-test-linux: name: "Test" diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml index 3232652b6b..40a3b0bec8 100644 --- a/.github/workflows/therock-ci.yml +++ b/.github/workflows/therock-ci.yml @@ -56,7 +56,14 @@ jobs: uses: ./.github/workflows/therock-ci-linux.yml secrets: inherit with: - cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../" + cmake_options: >- + -DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON + -DTHEROCK_ENABLE_MIOPEN=ON + -DTHEROCK_ENABLE_ALL=OFF + -DTHEROCK_USE_EXTERNAL_COMPOSABLE_KERNEL=ON + -DTHEROCK_COMPOSABLE_KERNEL_SOURCE_DIR=../composable_kernel + -DTHEROCK_USE_EXTERNAL_ROCM_LIBRARIES=ON + -DTHEROCK_ROCM_LIBRARIES_SOURCE_DIR=../ amdgpu_families: "gfx94X-dcgpu" test_runs_on: "linux-mi325-1gpu-ossci-rocm" diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml new file mode 100644 index 0000000000..674e93c1de --- /dev/null +++ b/.github/workflows/therock-test-component.yml @@ -0,0 +1,71 @@ +name: Test component + +on: + workflow_call: + inputs: + artifact_run_id: + type: string + default: "" + amdgpu_families: + type: string + test_runs_on: + type: string + platform: + type: string + component: + type: string + + +permissions: + contents: read + +jobs: + test_component: + name: 'Test ${{ fromJSON(inputs.component).job_name }} (shard ${{ matrix.shard }} of ${{ fromJSON(inputs.component).total_shards }})' + runs-on: ${{ inputs.test_runs_on }} + container: + image: ${{ inputs.platform == 'linux' && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:4150afe4759d14822f0e3f8930e1124f26e11f68b5c7b91ec9a02b20b1ebbb98' || null }} + options: --ipc host + --group-add video + --device /dev/kfd + --device /dev/dri + --group-add 992 + --env-file /etc/podinfo/gha-gpu-isolation-settings + strategy: + fail-fast: false + matrix: + # The shard array is based on "total_shards" from "fetch_test_configurations.py" + # The test executable will shard based on the array. (ex: [1, 2, 3, 4] = four test shards) + shard: ${{ fromJSON(inputs.component).shard_arr }} + defaults: + run: + shell: bash + env: + VENV_DIR: ${{ github.workspace }}/.venv + ARTIFACT_RUN_ID: "${{ inputs.artifact_run_id != '' && inputs.artifact_run_id || github.run_id }}" + OUTPUT_ARTIFACTS_DIR: "./build" + THEROCK_BIN_DIR: "./build/bin" + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + steps: + - name: Checkout Repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + repository: "ROCm/TheRock" + + - name: Run setup test environment workflow + uses: './.github/actions/setup_test_environment' + with: + ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} + VENV_DIR: ${{ env.VENV_DIR }} + FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }} + IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} + + - name: Test + timeout-minutes: ${{ fromJSON(inputs.component).timeout_minutes }} + env: + SHARD_INDEX: ${{ matrix.shard }} + TOTAL_SHARDS: ${{ fromJSON(inputs.component).total_shards }} + run: | + ${{ fromJSON(inputs.component).test_script }} diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index 37ddd399ad..54e068eb3d 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -37,41 +37,17 @@ jobs: test_components: name: 'Test ${{ matrix.components.job_name }}' - runs-on: ${{ inputs.test_runs_on }} - needs: configure_test_matrix + needs: [configure_test_matrix] # skip tests if no test matrix to run if: ${{ needs.configure_test_matrix.outputs.components != '[]' }} strategy: fail-fast: false matrix: components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }} - defaults: - run: - shell: bash - env: - VENV_DIR: ${{ github.workspace }}/.venv - ARTIFACT_RUN_ID: "${{ github.run_id }}" - OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build - THEROCK_BIN_DIR: "./build/bin" - steps: - - name: Checkout Repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: "ROCm/TheRock" - - - name: Run setup test environment workflow - uses: './.github/actions/setup_test_environment' - with: - ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} - AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} - OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} - VENV_DIR: ${{ env.VENV_DIR }} - FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }} - PLATFORM: ${{ inputs.platform }} - IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} - - - name: Test - timeout-minutes: ${{ matrix.components.timeout_minutes }} - run: | - if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . ${VENV_DIR}/Scripts/activate ; fi - ${{ matrix.components.test_script }} + uses: './.github/workflows/therock-test-component.yml' + with: + artifact_run_id: ${{ github.run_id }} + amdgpu_families: ${{ inputs.amdgpu_families }} + test_runs_on: ${{ inputs.test_runs_on }} + platform: ${{ inputs.platform }} + component: ${{ toJSON(matrix.components) }} diff --git a/CHANGELOG.md b/CHANGELOG.md index f21795012d..fe1e7ef345 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. +* Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data diff --git a/CMakeLists.txt b/CMakeLists.txt index 26d91fe6d8..88b8f05200 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -339,6 +339,7 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) +option(ENABLE_JSON_DUMP "Whether to enable json dump for examples." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -352,6 +353,11 @@ if(ENABLE_ASM_DUMP) message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() +if (ENABLE_JSON_DUMP) + add_compile_definitions(CK_ENABLE_JSON_DUMP) + message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/Dockerfile b/Dockerfile index 6f5cd0115d..07327442fe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,27 +1,23 @@ + FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=6.4.1 +ARG ROCMVERSION=7.0.1 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn +ENV DEBIAN_FRONTEND=noninteractive # Add rocm repository RUN set -xe && \ - apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ - curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg + apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl -RUN if [ "$ROCMVERSION" != "6.5" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60401-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60401-1_all.deb && \ - wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \ - fi - -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu jammy main universe | tee -a /etc/apt/sources.list" && \ - amdgpu-install -y --usecase=rocm --no-dkms +RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \ + apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \ + apt update && \ + apt install python3-setuptools python3-wheel -y && \ + apt install rocm-dev -y ## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache @@ -45,7 +41,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libelf-dev \ libnuma-dev \ libpthread-stubs0-dev \ - llvm-amdgpu \ mpich \ net-tools \ pkg-config \ @@ -61,17 +56,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- zip \ libzstd-dev \ openssh-server \ - clang-format-12 \ clang-format-18 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ rm -rf amdgpu-install* && \ -# Remove unnecessary rocm components that take a lot of space - apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt - #Install latest ccache -RUN git clone https://github.com/ccache/ccache.git && \ + git clone https://github.com/ccache/ccache.git && \ cd ccache && mkdir build && cd build && cmake .. && make install && \ #Install ninja build tracing tools cd / && \ diff --git a/Dockerfile.compiler b/Dockerfile.compiler index 0306057e45..47bd8294b6 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4.1" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index efe08a7d41..d494b0bf49 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -53,7 +53,7 @@ def getBaseDockerImageName(){ } else{ def ROCM_numeric = parseVersion("${params.ROCMVERSION}") - if ( ROCM_numeric.major <= 6 && ROCM_numeric.minor < 5 ){ + if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){ img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ @@ -476,7 +476,7 @@ def buildHipClangJob(Map conf=[:]){ def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 20, unit: 'HOURS') { @@ -538,7 +538,7 @@ def Build_CK(Map conf=[:]){ def image def retimage - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -728,7 +728,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -836,7 +836,7 @@ def run_aiter_tests(Map conf=[:]){ dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " echo "Docker flags: ${dockerOpts}" - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -859,6 +859,7 @@ def run_aiter_tests(Map conf=[:]){ sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py" @@ -894,7 +895,7 @@ def run_pytorch_tests(Map conf=[:]){ dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " echo "Docker flags: ${dockerOpts}" - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -930,7 +931,8 @@ def run_pytorch_tests(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true + 0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true @@ -957,8 +959,8 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '6.4.1', - description: 'Specify which ROCM version to use: 6.4.1 (default).') + defaultValue: '7.0.1', + description: 'Specify which ROCM version to use: 7.0.1 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', @@ -1037,8 +1039,8 @@ pipeline { description: "Build CK and run tests on gfx942 (default: ON)") booleanParam( name: "BUILD_GFX950", - defaultValue: false, - description: "Build CK and run tests on gfx950 (default: OFF)") + defaultValue: true, + description: "Build CK and run tests on gfx950 (default: ON)") booleanParam( name: "BUILD_GFX10", defaultValue: true, @@ -1290,7 +1292,7 @@ pipeline { agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" - execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \ + execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_PREFIX_PATH=/opt/rocm ../codegen && \ make -j64 check""" } steps{ @@ -1350,7 +1352,6 @@ pipeline { } agent{ label rocmnode("gfx950") } environment{ - def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0" setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx950 && \ make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \ @@ -1358,7 +1359,7 @@ pipeline { example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx950 """ } steps{ - buildHipClangJobAndReboot(setup_args:setup_args, docker_name: docker_name, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() } } @@ -1566,7 +1567,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } @@ -1631,7 +1632,7 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0") + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB}:ck_ub24.04_rocm7.0.1") } cleanWs() } @@ -1657,13 +1658,13 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on gfx1101") + stage("Build CK and run Tests on gfx11") { when { beforeAgent true expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx1101") } + agent{ label 'miopen && (gfx1101 || gfx1100)' } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 2b2e6e2949..80429a781b 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -12,6 +12,7 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h) find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) +find_package(hiprtc REQUIRED) rocm_setup_version(VERSION 1.0) @@ -27,7 +28,7 @@ add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) -target_link_libraries(ck_host PRIVATE ck_headers) +target_link_libraries(ck_host PRIVATE ck_headers hiprtc::hiprtc) set_target_properties(ck_host PROPERTIES LINKER_LANGUAGE CXX diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index d149fd88f1..d5c42558c4 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -36,7 +36,7 @@ using BDataType = ck::half_t; using CDataType = ck::half_t; using AccDataType = float; #else - < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>; + < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>; using ADataType = float; using BDataType = float; using CDataType = float; @@ -185,7 +185,6 @@ int main(int argc, char* argv[]) auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; - // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); @@ -209,8 +208,7 @@ int main(int argc, char* argv[]) return 0; } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 08e2b8c15f..7fb0c1e812 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -2,7 +2,6 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/library/utility/validation_common.hpp" // use macro to minimize code change #ifndef EXAMPLE_WITH_COMPUTE_DATATYPE @@ -29,11 +28,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if constexpr(std::is_same_v) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; @@ -59,17 +58,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); - try - { - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - } - catch(const std::runtime_error& e) - { - std::cerr << "Error: " << e.what() << std::endl; - return false; - } - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp index bffa2e5640..992e7c19c8 100644 --- a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp +++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -174,6 +174,9 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + const auto StrideD = std::is_same::value + ? d_m_n.mDesc.GetStrides()[0] + : d_m_n.mDesc.GetStrides()[1]; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; @@ -221,7 +224,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{0}, + std::array{static_cast(StrideD)}, StrideE, a_element_op, b_element_op, diff --git a/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc index cb0271c81f..796a5d3e9b 100644 --- a/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc +++ b/example/04_gemm_add_add_fastgelu/run_gemm_add_add_fastgelu_example.inc @@ -7,7 +7,9 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC #endif using namespace ck::literals; - auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; + ProblemSize ps = + problem_size; // make mutable copy because default stride values of 0 need to be updated + auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = ps; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -41,6 +43,30 @@ bool run_gemm_add_add_fastgelu(const ProblemSize& problem_size, const ExecutionC std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + // If any user-provided leading stride <= 0, replace it with the one determined by the + // created tensor descriptor. For RowMajor the leading stride is index 0, for ColMajor index 1. + auto fetch_leading_stride = [](const auto& tensor, auto layout_tag) -> int { + if constexpr(std::is_same_v) + { + return static_cast(tensor.GetStrides()[0]); + } + else + { + return static_cast(tensor.GetStrides()[1]); + } + }; + + if(StrideA <= 0) + StrideA = fetch_leading_stride(a_m_k, ALayout{}); + if(StrideB <= 0) + StrideB = fetch_leading_stride(b_k_n, BLayout{}); + if(StrideD0 <= 0) + StrideD0 = fetch_leading_stride(d0_m_n, D0Layout{}); + if(StrideD1 <= 0) + StrideD1 = fetch_leading_stride(d1_m_n, D1Layout{}); + if(StrideE <= 0) + StrideE = fetch_leading_stride(e_m_n_host_result, ELayout{}); + switch(config.init_method) { case 0: break; diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 036f288d0a..7142521c55 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -125,7 +125,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); problem_size = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index c4e7068499..4b290d02a2 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -23,7 +23,7 @@ using RsGlobalReduceOp = static constexpr auto ConvSpec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off template diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp index 3ce08fd2af..abbf1b29f7 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_common.hpp +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -78,12 +78,12 @@ bool pool_test(bool do_verification, if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; diff --git a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp index 2585072dfe..5291f5ce69 100644 --- a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp @@ -115,12 +115,14 @@ int main() if(std::is_same::value) { return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1_uz})); + std::vector({stride, 1_uz}), + layout); } else { return HostTensorDescriptor(std::vector({row, col}), - std::vector({1_uz, stride})); + std::vector({1_uz, stride}), + layout); } }; diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 13da444051..4a701e7792 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -137,11 +137,13 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {row * stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {col * stride, 1_uz, stride}, layout); } }; diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index ce9f9b7032..ae5e3f36ad 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -65,7 +65,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern //######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| //######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>; + < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, S<8, 32>, 4>; // clang-format on auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { @@ -154,8 +154,8 @@ void host_gemm_layernorm(Tensor& h_m_n, int main() { - // temp disable on gfx11 & gfx12 - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + // temp disable on gfx11 + if(ck::is_gfx11_supported()) { return 0; } diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc index 741512bf00..c93a2051d2 100644 --- a/example/24_batched_gemm/run_batched_gemm_example.inc +++ b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -59,11 +59,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(std::is_same::value) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 3582bc5e33..ac34ed5b8a 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -137,11 +137,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto layout) { if constexpr(std::is_same_v) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc index 778be8ffd7..9939429a08 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc @@ -64,11 +64,13 @@ bool run_batched_gemm_rowwise(const ProblemSize& problem_size, const ExecutionCo if(std::is_same::value) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp index 420a7cf74f..4f4003809b 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -19,6 +19,9 @@ #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -342,7 +345,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_gs_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1 using S = ck::Sequence; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -342,7 +345,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_gs_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1 a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; @@ -189,7 +191,7 @@ int run_contraction_bilinear_example(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; @@ -173,7 +175,7 @@ int run_contraction_scale_example(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 using S = ck::Sequence; @@ -304,10 +307,10 @@ int main(int argc, char* argv[]) const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Bypass{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); ck::index_t M_ = ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); @@ -416,9 +419,9 @@ int main(int argc, char* argv[]) const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data()); diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index f556be887f..c4cb7a13a2 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -17,6 +17,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -300,11 +303,11 @@ int main(int argc, char* argv[]) std::vector e_gs_ms_ns_strides{ G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; @@ -396,7 +399,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 using S = ck::Sequence; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -345,7 +348,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + #include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc" int main(int argc, char* argv[]) diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc index 255a0cddaf..7a03e9cacf 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc @@ -110,11 +110,13 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc index 8ab47c2925..cea18459f4 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc @@ -62,17 +62,19 @@ int run(int argc, char* argv[]) std::vector b1_g_o_n_lengths{G, O, N}; #ifdef CK_MHA_USE_RCCR_LAYOUT std::vector b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N] + auto b1_layout = Row{}; #else std::vector b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O] + auto b1_layout = Col{}; #endif std::vector c_g_m_o_lengths{G, M, O}; std::vector c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O] - Tensor a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides); - Tensor b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides); - Tensor b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides); - Tensor c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides); - Tensor c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides); + Tensor a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides, Row{}); + Tensor b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides, Row{}); + Tensor b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides, b1_layout); + Tensor c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides, Row{}); + Tensor c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides, Row{}); std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc index 1514fc48b3..aa2a6b3b42 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -111,12 +111,14 @@ int run(int argc, char* argv[]) if(std::is_same::value) { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, stride, 1})); + std::vector({batch_stride, stride, 1}), + layout); } else { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, 1, stride})); + std::vector({batch_stride, 1, stride}), + layout); } }; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc index 2b02069e65..6175f0b5be 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -88,11 +90,11 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Bypass{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Bypass{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Bypass{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc index e0ccb6dad1..db13e3b963 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -88,11 +92,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc index 0ad031cc71..1e4b52d4cf 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -113,11 +117,30 @@ int run(int argc, char* argv[]) head_dim, 1}; // C layout [batch_size, head_num, q_sequence_length, head_dim] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; @@ -191,7 +214,7 @@ int run(int argc, char* argv[]) head_num * 2 * head_dim, head_dim, 1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim] - Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides); + Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides, Bypass{}); // merge kv into a packed pointer send to device b0_gs_ns_ks.ForEach( [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index c693995140..874d987a1d 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -63,6 +67,19 @@ int run(int argc, char* argv[]) std::size_t flop = 0, num_byte = 0; + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; std::cout << "group count " << group_count << ". printing first 4 groups\n"; for(std::size_t i = 0; i < group_count; i++) { @@ -113,10 +130,14 @@ int run(int argc, char* argv[]) {}}); // acc1_biases_gs_ms_os_strides // C_m_o = A_m_k * B0_k_n * B1_n_o - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks(f_host_tensor_descriptor( + b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns(f_host_tensor_descriptor( + b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_device_result(f_host_tensor_descriptor( + c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); int Batch = G0 * G1; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; @@ -252,7 +273,8 @@ int run(int argc, char* argv[]) Tensor acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 Tensor a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax Tensor c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_host_result(f_host_tensor_descriptor( + c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); // permute a_gs_ms_ks.ForEach([&](auto& self, auto idx) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc index 7ac29f33ca..1c2a26d916 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -91,11 +95,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc index fb9b1b0bd7..76f3ee756c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -91,11 +95,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc index 2cb69380e5..86754927ed 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -108,11 +112,30 @@ int run(int argc, char* argv[]) head_dim, 1}; // C layout [batch_size, head_num, sequence_length, head_dim] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; @@ -186,7 +209,7 @@ int run(int argc, char* argv[]) head_num * 3 * head_dim, head_dim, 1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim] - Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides); + Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides, Bypass{}); // merge qkv into a packed pointer send to device a_gs_ms_ks.ForEach( [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index 904ff761fd..4934f74393 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -321,11 +321,13 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc index 0f0b120cbc..80d56cd781 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc @@ -206,7 +206,8 @@ int run_grouped_conv_bwd_data_bias_relu_example(int argc, char* argv[]) 1, // c 0, // hi 0 // wi - }); + }, + ctc::GNCHW{}); // input image: GNHWC const auto in_g_n_c_wis_desc = diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc index 30e0791ebf..3c089688cf 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc @@ -214,7 +214,8 @@ int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_ 1, // k 0, // ho 0 // wo - }); + }, + BiasLayout{}); const auto requant_scale_g_k_desc = bias_g_k_desc; diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc index 32fd435e00..ed7886e76b 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc @@ -201,7 +201,8 @@ int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_el 1, // k 0, // ho 0 // wo - }); + }, + BiasLayout{}); const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc index 362d90b4c1..12fdf425bf 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc @@ -203,7 +203,8 @@ int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_eleme 1, // k 0, // ho 0 // wo - }); + }, + RequantScaleLayout{}); const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp index ebba88cf41..b5e9686260 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp @@ -22,6 +22,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -250,19 +253,24 @@ int main(int argc, char* argv[]) Tensor a_gs_ms_ks( std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), - std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()), + Row{}); Tensor b_gs_ns_ks( std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), - std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()), + Row{}); Tensor d_gs_ms_ns( std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), - std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_device_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -372,7 +380,8 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -250,19 +253,24 @@ int main(int argc, char* argv[]) Tensor a_gs_ms_ks( std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), - std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()), + Row{}); Tensor b_gs_ns_ks( std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), - std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()), + Row{}); Tensor d_gs_ms_ns( std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), - std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_device_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -372,7 +380,8 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1, 2> as = {Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides)}; + std::array, 2> as = {Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{})}; Tensor& a0 = as[0]; Tensor& a1 = as[1]; - Tensor b(ab_lengths, ab_strides); + Tensor b(ab_lengths, ab_strides, NchwLayout{}); float alpha = 3.f; float beta = 2.f; a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -134,7 +136,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, ab_strides); + Tensor host_b(ab_lengths, ab_strides, NchwLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 9e92543252..2d689648f2 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple @@ -72,9 +74,9 @@ int main(int argc, char* argv[]) static_cast(nhwc[3])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -117,7 +119,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index 88c23b5f40..6e70a306d3 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -23,6 +23,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +78,9 @@ int main(int argc, char* argv[]) static_cast(nhwc[0] * nhwc[1])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 1.f; auto i = 0; std::mt19937 gen(11939); @@ -137,7 +139,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index 1185b5a3ca..632d88e88a 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +79,9 @@ int main(int argc, char* argv[]) ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 2.f; a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -128,7 +131,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index 28a3dbc44c..bd54f1c19c 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F32; using BDataType = F32; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +78,9 @@ int main(int argc, char* argv[]) static_cast(nhwc[0] * nhwc[1])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 1.f; auto i = 0; @@ -139,7 +141,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index 14d1d96165..9621d591a9 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F32; using BDataType = F32; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +79,9 @@ int main(int argc, char* argv[]) ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 2.f; a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -127,7 +130,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp index 2583f1cb5e..be4014f636 100644 --- a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -78,13 +81,13 @@ int main(int argc, char* argv[]) ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 3> as = {Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides)}; + std::array, 3> as = {Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{})}; Tensor& a0 = as[0]; Tensor& a1 = as[1]; Tensor& a2 = as[2]; - Tensor b(ab_lengths, ab_strides); + Tensor b(ab_lengths, ab_strides, NchwLayout{}); float alpha = 3.f; float beta = 2.f; float gamma = 4.f; @@ -149,7 +152,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, ab_strides); + Tensor host_b(ab_lengths, ab_strides, NchwLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc index e1b2bccfe1..24807aeeb3 100644 --- a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc +++ b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc @@ -1,22 +1,30 @@ #pragma once +#include bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; - auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; + ProblemSize ps = + problem_size; // make mutable copy because default stride values of 0 need to be updated + auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = ps; - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, int& stride, auto layout) { + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp index 1b24bd3bba..3e69caf51e 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp @@ -18,6 +18,10 @@ #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -220,12 +224,12 @@ int main(int argc, char* argv[]) std::vector d0_gs_ms_ns_lengths{G0, G1, M, N}; std::vector d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Row{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Col{}); + Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides, Row{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Row{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Row{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/48_pool3d_fwd/pool3d_fwd_common.hpp b/example/48_pool3d_fwd/pool3d_fwd_common.hpp index 788f38ec52..ef64dd167d 100644 --- a/example/48_pool3d_fwd/pool3d_fwd_common.hpp +++ b/example/48_pool3d_fwd/pool3d_fwd_common.hpp @@ -48,15 +48,16 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, layout); } throw std::runtime_error("Pool3d_fwd: problem with layout. "); - return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, layout); }; template ::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, layout); } throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); - return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, layout); }; template +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = AddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + AsLayout, + BsLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<8, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<8, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 2; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(D0DataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 1; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, StrideB}, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000..086a0f4834 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fastgelu_bf16_i8.cpp @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AElementOp = PassThrough; +using BElementOp = Multiply; +using CDEElementOp = FastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + AsLayout, + BsLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<8, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<8, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 2; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideE = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB, B1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{0, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 2; + constexpr ck::index_t NumDTensor = 0; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer()}, + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB, StrideB}, + std::array{}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + b_element_op(b_k_n(k, n), b0_k_n(k, n), b1_k_n(k, n)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp new file mode 100644 index 0000000000..32345d1263 --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_fp16.cpp @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using ELayout = Row; + +struct AddScale +{ + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + static constexpr auto I2 = ck::Number<2>{}; + static constexpr auto I3 = ck::Number<3>{}; + + __host__ __device__ constexpr void + operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const + { + const auto a0_v_t = ck::vector_type{a0}; + const auto a1_v_t = ck::vector_type{a1}; + + auto r_v_t = ck::vector_type{}; + + r_v_t.AsType()(I0) = + scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); + r_v_t.AsType()(I1) = + scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); + r_v_t.AsType()(I2) = + scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); + r_v_t.AsType()(I3) = + scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); + + a = r_v_t.AsType()[I0]; + } + + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const + { + a = scale * (a0 + a1); + } + + // this attribute controls the copy_function applying element_wise_op with + // pack4_data + constexpr const static bool is_pack4_invocable = true; + + float scale = 1.0; +}; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ELayout, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 256, + 128, + 32, + 8, + 8, + 16, + 16, + 4, + 4, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 64, 1, 4>, + S<8, 8, 8>>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 12: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{0.2}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA, StrideA}, + std::array{StrideB}, + std::array{StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp new file mode 100644 index 0000000000..00e2d7e33c --- /dev/null +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_wmma_multiply_bias_fastgelu_bf16_i8.cpp @@ -0,0 +1,296 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using D1DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using D1Layout = D0Layout; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyAddFastGelu; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmma_CShuffleV3< + AsLayout, + BsLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<8, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 0, + S<8, 32, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<8, 8, 8>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 4096; + ck::index_t N = 768; + ck::index_t K = 6144; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideD = N; + ck::index_t StrideE = N; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumATensor = 1; + constexpr ck::index_t NumBTensor = 1; + constexpr ck::index_t NumDTensor = 2; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer()}, + std::array{b0_device_buf.GetDeviceBuffer()}, + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + std::array{StrideD, StrideD}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp index 5220a4616e..405eac7df1 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp @@ -81,10 +81,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -120,23 +121,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) N, K, std::array{StrideA}, - std::array{StrideB, 0}, + std::array{StrideB, StrideB1}, std::array{StrideD}, StrideE, a_element_op, diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp index b424fdaf45..50e670bdf3 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp @@ -81,10 +81,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -120,23 +121,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) N, K, std::array{StrideA}, - std::array{StrideB, 0}, + std::array{StrideB, StrideB1}, std::array{}, StrideE, a_element_op, diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp index 03a74c04b7..50e1c21c8f 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp @@ -80,10 +80,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -119,23 +120,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) K, std::array{StrideA}, std::array{StrideB}, - std::array{0, StrideD}, + std::array{StrideB1, StrideD}, StrideE, a_element_op, b_element_op, @@ -261,7 +270,7 @@ int main(int argc, char* argv[]) { for(int n = 0; n < N; ++n) { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(0, n), d_m_n(m, n)); + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(m, n), d_m_n(m, n)); } } diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 90e14de59c..a9a30b4c27 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -19,6 +19,9 @@ #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -160,12 +163,12 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides, Bypass{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; @@ -264,9 +267,9 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) { @@ -299,7 +302,6 @@ int main(int argc, char* argv[]) auto ref_op = ReferenceOpInstance{}; auto ref_invoker = ref_op.MakeInvoker(); - Tensor empty_tensor(std::vector{}, std::vector{}); auto ref_argument = ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op); diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp index ec1b2d6018..4f7414abfa 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -19,6 +19,9 @@ #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -140,12 +143,12 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); - Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides, Bypass{}); + Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides, Row{}); + Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; @@ -246,9 +249,9 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) { @@ -266,7 +269,7 @@ int main(int argc, char* argv[]) } } - Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides, Row{}); for(size_t n0 = 0; n0 < b_ns_ks.mDesc.GetLengths()[0]; ++n0) { diff --git a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp index 2afe01f02d..0a802ee27d 100644 --- a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp @@ -130,11 +130,12 @@ bool run_grouped_conv(bool do_verification, // Fill other lenghts than G,K with 1 and strides with 0 bias_g_k_lengths.fill(1); bias_g_k_strides.fill(0); - bias_g_k_lengths[0] = G; - bias_g_k_lengths[2] = K; - bias_g_k_strides[0] = K; // stride to G - bias_g_k_strides[2] = 1; // stride to K - const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides); + bias_g_k_lengths[0] = G; + bias_g_k_lengths[2] = K; + bias_g_k_strides[0] = K; // stride to G + bias_g_k_strides[2] = 1; // stride to K + const auto broadcasted_bias_desc = + HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides, BiasLayout{}); // y = relu ( alpha1 * conv(x) + alpha2 * z + bias ) Tensor in(in_g_n_c_wis_desc); diff --git a/example/64_fpAintB_gemm/run_gemm_example.inc b/example/64_fpAintB_gemm/run_gemm_example.inc index dc2bdc18f0..41c8c42bac 100644 --- a/example/64_fpAintB_gemm/run_gemm_example.inc +++ b/example/64_fpAintB_gemm/run_gemm_example.inc @@ -28,7 +28,8 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor quant_b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // assume scale tensor is [1, n] - Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{})); + Tensor scale_k_n( + HostTensorDescriptor({K, N}, {0, 1_uz}, ck::tensor_layout::BypassLayoutVerification())); switch(config.init_method) { diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp index 53963fc514..8b8cee9e52 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp @@ -241,6 +241,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -285,8 +307,6 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; @@ -308,7 +328,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index 7a2d0153d9..8da49ef85d 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -162,6 +162,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -216,7 +238,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{StrideD, StrideD}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index fe1eca51b0..3ee4955ae4 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -251,6 +251,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -295,8 +317,6 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; @@ -318,7 +338,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 52ba3416a0..72ea7f1cb6 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -287,15 +287,18 @@ int main(int argc, char* argv[]) } } Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; @@ -422,7 +425,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( - {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}, Row{})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N * 2}, - {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; @@ -463,7 +467,7 @@ int main(int argc, char* argv[]) Tensor b_e_n_k({experts, K, N * 2}); e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); // handle scale before ref. for(int t = 0; t < tokens; ++t) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 92a0cd9e5c..5e306ac6dd 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -264,15 +264,18 @@ int main(int argc, char* argv[]) } Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; @@ -488,7 +491,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor d0_t_n( - HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0})); + HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}, Bypass{})); Tensor d1_e_n( HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 354957c0d1..cc42c4b815 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -292,17 +292,19 @@ int main(int argc, char* argv[]) } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + Scale_Block_K - 1) / Scale_Block_K}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k(HostTensorDescriptor( {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N}, - {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 6ca7d67f53..29e758f9d4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -29,8 +29,9 @@ using F16 = ck::half_t; using F8 = ck::f8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F8; using B0DataType = I4; @@ -239,10 +240,10 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}, Bypass{})); Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc index 82ac0a15e1..b08d12de86 100644 --- a/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc +++ b/example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc @@ -95,25 +95,26 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) exit(0); } + using DefaultLayout = ck::tensor_layout::gemm::RowMajor; // For Real Part of Complex Tensor - Tensor a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides); + Tensor a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides, DefaultLayout{}); + Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides, DefaultLayout{}); + Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides, DefaultLayout{}); - Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); // For Imaginary Part of Complex Tensor - Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides); + Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides, DefaultLayout{}); + Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides, DefaultLayout{}); + Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides, DefaultLayout{}); - Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); // Intermediate E tensor Definition - Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl; std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl; @@ -349,8 +350,10 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) if(do_verification) { // Real Part Verification - Tensor c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); - Tensor c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_re( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor c_ms_ns_host_result_re1( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); - Tensor c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_img( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor c_ms_ns_host_result_img1( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); auto ref_argument_img = ref_op.MakeArgument( a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op); diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index aaf0cb3891..69c0d6558f 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -269,10 +269,12 @@ int main(int argc, char* argv[]) Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -281,12 +283,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -480,7 +483,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -278,12 +280,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -477,7 +480,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -310,12 +313,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -506,7 +510,7 @@ int main(int argc, char* argv[]) { invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -286,7 +288,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 829bf9af24..5bb6454d2a 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -268,16 +268,18 @@ int main(int argc, char* argv[]) } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -286,7 +288,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index efbd0f0c03..333f8a3d52 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -303,16 +303,18 @@ int main(int argc, char* argv[]) expert_ids.savetxt("expert_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -321,7 +323,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 3d79f2f6d3..b8ca26193d 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -47,7 +47,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api bwd --receipt 3 - --optdim 32,64,128,256 + --optdim 32,64,96,128,256 # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... ) @@ -169,6 +169,10 @@ if(CK_USE_OCP_FP8) list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() +# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) + target_compile_options(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS} INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS}) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 7f55d7412f..2b872cb9b5 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -36,6 +36,13 @@ args: total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) -s_k seqlen_k (including new key/value), -1 means equal to s (default:-1) + also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode) + -s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1) + Provide positive strides per-batch to simulate physical padding on Q + -s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1) + for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride + along seqlen, instead of packed, same as xformer kv_padding, + must be greater than or equal to s_k -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) @@ -76,11 +83,20 @@ args: -repeat number of iterations to benchmark the kernel (default:20) -json 0: No Json, 1: Dump Results in Json format (default:0) -jsonfile json file name to dump results (default:fmha_fwd.json) + -q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override +-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override ``` Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case +## Padding Examples +Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively. + +Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively. + ## support features Currently we are still in rapid development stage, so more features/optimizations will be coming soon. @@ -128,6 +144,15 @@ Note FA use bottom-right by default to express swa case, here we require you exp ### dropout TBD +### sequence padding and variable length support +We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths. + +**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment. + +**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste. + +Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios. + ## FP8 experimental support As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 8f710050b1..bd6a9044e9 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -50,16 +50,10 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits; using fmha_mask_{F_idx} = {F_mask}; using fmha_dropout_{F_idx} = {F_dropout}; @@ -94,19 +88,19 @@ using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, false, - {F_dvpad}>>; + ({F_dvpad} > 0)>>; using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dq_dk_dv_kernel_{F_idx} = ck_tile::FmhaBwdDQDKDVKernel float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ - const bool has_load_tr = ck_tile::is_load_tr_supported(); + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); float r = -1; {F_dispatch} return r; @@ -220,9 +214,9 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) && ({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{ - using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; r = fmha_bwd_>(s, a); return r; }} @@ -278,8 +272,8 @@ class FmhaBwdDQDKDVKernel: F_hdim : int # hdim F_dtype : str # data type F_tile : FmhaBwdDQDKDVTileSize - F_dpad : str # - F_dvpad : str # + F_dpad : Literal[0, 8 ,1] + F_dvpad : Literal[0, 8 ,1] F_bias : str # F_dbias : str # F_dropout : str # @@ -320,8 +314,8 @@ class FmhaBwdDQDKDVKernel: F_wm1 = self.F_tile.F_wm1, F_wn1 = self.F_tile.F_wn1, F_wk1 = self.F_tile.F_wk1, - F_dpad = BOOL_MAP[self.F_dpad], - F_dvpad = BOOL_MAP[self.F_dvpad], + F_dpad = self.F_dpad, + F_dvpad = self.F_dvpad, F_bias = BIAS_MAP[self.F_bias], F_dbias = BOOL_MAP[self.F_dbias], F_dropout = DROPOUT_MAP[self.F_dropout], @@ -337,8 +331,8 @@ class FmhaBwdDQDKDVKernel: def name(self) -> str: def pad_name() -> str: n = '' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' + if self.F_dpad : n += f'd{self.F_dpad}' + if self.F_dvpad : n += f'dv{self.F_dvpad}' if n != '' : n = 'p' + n return n pn = pad_name() @@ -380,6 +374,7 @@ def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize] return [ FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), @@ -621,8 +616,8 @@ class FmhaBwdApiTrait: dbias : str dropout : str spad1d : str # spad for 1d kernels (dot/convert) - dpad : str - dvpad : str + dpad : Literal[0, 1, 8] + dvpad : Literal[0, 1, 8] deterministic : str mask_impl : str tr_load : str @@ -651,13 +646,13 @@ class FmhaBwdApiTrait: @property def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' + if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0' + else: return f'a.hdim_q % {self.dpad} == 0' @property def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' + if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0' + else: return f'a.hdim_v % {self.dvpad} == 0' @property def extra_cond(self) -> str: @@ -677,8 +672,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dvpad = 't' if self.dvpad else 'f' return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, - F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) @property def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: @@ -693,8 +689,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dpad = 't' if self.dpad else 'f' return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad, + F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) @@ -720,7 +717,7 @@ class FmhaBwdApiPool: F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], - F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad, F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond, F_convert_dq_bn0=trait.convert_dq_bn0) @@ -793,7 +790,10 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) - for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): + dpad_options = itertools.product(*([[0, 8, 1]] * 2)) + tf = ["t", "f"] + for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product( + tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf): assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): @@ -804,8 +804,12 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if ("wg32" in dropout): continue - if tr_load == "t" and (dpad == "t" or dvpad == "t"): + if tr_load == "t": continue # tr_load cannot work with dpad or dvpad + else: # tr_load == "f" + # do not generate instance with only 1 of dpad/dvpad being 8 + if dpad != dvpad and dpad == 8: + continue if optdim_list != [-1]: if hdim not in optdim_list: continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index cfb96b7d53..da0c9ca931 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -259,11 +259,11 @@ class FmhaFwdApiTrait: def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag == 'qr_async_trload': if self.skpad == 't' : return 'true' else: return 'true' diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 91cb9f55be..79fda6d564 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -33,6 +33,10 @@ auto create_args(int argc, char* argv[]) "0", "seqlen_k for new key/value, 0 means not to use this at all; " "-1 to choose s_knew in [1, s] randomly.") + .insert("s_qpad", + "-1", + "seqlen_q stride between 2 batches (group-mode optional).\n" + "Provide positive strides per-batch to simulate physical padding on Q.") .insert("s_kpad", "-1", "seqlen_k stride between 2 batches, currently used in group-mode only\n" @@ -107,7 +111,15 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); + .insert("jsonfile", "fmha_fwd.json", "json file name to dump results") + .insert("q_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override.") + .insert("kv_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -127,6 +139,9 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); + auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); + auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens"); + auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens"); ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); bool i_perm = arg_parser.get_bool("iperm"); bool o_perm = arg_parser.get_bool("operm"); @@ -174,7 +189,10 @@ auto run(const ck_tile::ArgParser& arg_parser) hdim_q, hdim_v, seqlen_knew, + seqlen_qpads, seqlen_kpads, + q_eff_lens_per_batch, + kv_eff_lens_per_batch, rotary_dim, i_perm, o_perm, diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 569c98a458..7ddb65a2db 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -52,7 +52,16 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair get_query_shape() const @@ -172,6 +183,8 @@ struct Problem mask_info mask; TensorLayout input_layout; TensorLayout output_layout; + std::vector q_eff_lens; + std::vector kv_eff_lens; }; struct RunConfig @@ -326,8 +339,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) q_buf.ToDevice(q.data()); k_buf.ToDevice(k.data()); v_buf.ToDevice(v.data()); + // Ensure output buffer is zero-initialized so padded regions compare cleanly + o_buf.SetZero(); - ck_tile::fmha_fwd_v3_args args; + ck_tile::fmha_fwd_v3_args args{}; args.data_type = problem.data_type; args.batch = problem.batch; @@ -380,6 +395,60 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) : problem.seqlen_q * problem.hdim; args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; + // Optional cumulative seqlen overrides (exclude PAD) + const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; + const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; + + auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { + std::vector eff; + if(!opt_vec.empty() && opt_vec[0] != -1) + { + eff.assign(opt_vec.begin(), opt_vec.end()); + if(eff.size() < static_cast(problem.batch)) + { + eff.resize(problem.batch, eff.back()); + } + } + else + { + eff.assign(problem.batch, fallback); + } + return eff; + }; + + const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); + const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); + + // Calculate cumulative sums for kernel arguments if varlen is used + std::vector cuq_cum, cukv_cum; + auto calculate_cumulative = [&](const std::vector& per_batch_vec, + std::vector& cum_vec) { + cum_vec.resize(per_batch_vec.size() + 1); + cum_vec[0] = 0; + for(std::size_t i = 0; i < per_batch_vec.size(); ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + }; + + if(has_varlen_q) + { + calculate_cumulative(eff_q_vec, cuq_cum); + } + if(has_varlen_k) + { + calculate_cumulative(eff_kv_vec, cukv_cum); + } + + ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); + ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); + cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); + cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); + args.cu_seqlen_q_ptr = + !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) + : nullptr; + args.cu_seqlen_kv_ptr = + !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) + : nullptr; + ck_tile::stream_config stream_config{nullptr, true, /*log_level=*/0, @@ -442,15 +511,72 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) o_ref = o_ref.transpose({0, 2, 1, 3}); } - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); + // If variable lengths are provided, compute per-batch references + // with the effective lengths; else compute a single full reference. + if(has_varlen_q || has_varlen_k) + { + // Variable-length aware verification: zero-fill padded region and only compute valid part. + o_ref.SetZero(); + + for(int b = 0; b < problem.batch; ++b) + { + const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; + + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); + v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(b, s, h, d) = o_b(0, s, h, d); + } + } + } + } + } + else + { + // No varlen override: compute the full reference once + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + } ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index f1f8eee5e4..6cd1cd94fa 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -15,6 +15,10 @@ #include #include +struct FmhaBwdFp32 +{ +}; + struct FmhaBwdFp16 { }; @@ -26,6 +30,26 @@ struct FmhaBwdBf16 template struct FmhaBwdTypeConfig; +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using GemmDataType = float; + using BiasDataType = float; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = float; + using OGradDataType = float; + using QGradDataType = float; + using KGradDataType = float; + using VGradDataType = float; + using BiasGradDataType = float; +}; + template <> struct FmhaBwdTypeConfig { @@ -368,8 +392,8 @@ template {ck_tile::numeric::infinity()}(dq_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dk_host); ck_tile::FillConstant{ck_tile::numeric::infinity()}(dv_host); + ck_tile::FillConstant{ck_tile::numeric::infinity()}(dq_acc_host); dq_buf.ToDevice(dq_host.data()); dk_buf.ToDevice(dk_host.data()); dv_buf.ToDevice(dv_host.data()); + dq_acc_buf.ToDevice(dq_acc_host.data()); o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); - dq_buf.SetZero(); dbias_buf.SetZero(); - dq_acc_buf.SetZero(); + + // non-deterministic kernels use atomic add to write dq + // Some block may be skipped with causal mask and dq are not set to zeros + // In these cases thus we need to zero out it first + if(!deterministic || mask.type == mask_enum::no_mask) + dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; fmha_bwd(fmha_traits, fmha_args, stream_config_v); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index c41e48e6aa..f5dd42a6bd 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -162,11 +162,20 @@ struct fmha_fwd_args void* lse_ptr; void* o_ptr; + // Optional cumulative sequence length arrays + // Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD) + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + // Group mode: seqstart_padded_* provide physical starts including PAD (optional) + const void* seqstart_padded_q_ptr = nullptr; // [batch+1] + const void* seqstart_padded_k_ptr = nullptr; // [batch+1] + ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -554,7 +563,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.min_seqlen_q, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.seqstart_padded_q_ptr, + args.seqstart_padded_k_ptr); } else { // create batch mode kernel arguments @@ -600,7 +611,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 43f484fe14..5c6c7d923a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -151,7 +151,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t seqlen_knew, + std::vector seqlen_qpads, std::vector seqlen_kpads, + std::vector q_eff_lens_per_batch, + std::vector kv_eff_lens_per_batch, ck_tile::index_t rotary_dim, bool i_perm, bool o_perm, @@ -299,6 +302,24 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); + // Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv) + const bool has_group_padding = + (mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) || + (mode == mode_enum::group && (seqlen_kpads[0] >= 0)); + const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() || + !kv_eff_lens_per_batch.empty())); + const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); + const bool using_pagedkv = (0 < page_block_size); + const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; + if((using_appendkv || using_pagedkv || using_splitkv) && + (has_group_padding || has_batch_efflens)) + { + std::cerr << "Padding (physical or effective lengths) is not supported with " + "appendkv/splitkv/pagedkv pipelines" + << std::endl; + return fwd_result::invalid_args; + } + std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = generate_missing_seqlens(mode, batch, @@ -362,6 +383,44 @@ fwd_result fmha_fwd_run(mode_enum mode, const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + // Optional padded Q seqstarts (group-mode only) + std::vector seqstart_q_with_padding_host; + if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) + { + if(seqlen_qpads.size() < static_cast(batch)) + { + seqlen_qpads.resize(batch, seqlen_qpads.back()); + } + if(seqlen_qpads.size() == static_cast(batch)) + { + seqstart_q_with_padding_host = to_seqstarts( + ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); + } + } + + // Optional batch-mode cumulative seqlen overrides + std::vector cuq_cum, cukv_cum; + if(mode == mode_enum::batch) + { + auto calculate_cumulative = [&](std::vector& per_batch_vec, + std::vector& cum_vec) { + if(!per_batch_vec.empty() && per_batch_vec[0] != -1) + { + if(per_batch_vec.size() < static_cast(batch)) + { + per_batch_vec.resize(batch, per_batch_vec.back()); + } + cum_vec.resize(batch + 1); + cum_vec[0] = 0; + for(int i = 0; i < batch; ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + } + }; + + calculate_cumulative(q_eff_lens_per_batch, cuq_cum); + calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); + } + using TypeConfig = FmhaFwdTypeConfig; using QDataType = typename TypeConfig::QDataType; @@ -445,8 +504,15 @@ fwd_result fmha_fwd_run(mode_enum mode, // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - const ck_tile::index_t shape_seqlen_q = + // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen + const ck_tile::index_t shape_seqlen_q_lse = (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); + // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch + ? seqlen_qs[0] + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() + : seqstart_q_with_padding_host.back())); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() @@ -504,7 +570,7 @@ fwd_result fmha_fwd_run(mode_enum mode, // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q} + lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -602,6 +668,16 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() + ? 0 + : seqstart_q_with_padding_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k_padded_buf( + seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 + : cuq_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem cu_seqlen_kv_buf( + cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) @@ -693,8 +769,14 @@ fwd_result fmha_fwd_run(mode_enum mode, vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() - : seqstart_k_with_padding_host.data()); + // Keep logical starts in seqstart_k; pass padded K via separate pointer + seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_q_padded_buf.ToDevice( + seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); + seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr + : seqstart_k_with_padding_host.data()); + cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); + cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); @@ -747,6 +829,54 @@ fwd_result fmha_fwd_run(mode_enum mode, std::cout << ", cache_batch_idx:" << use_cache_batch_idx; } #endif + // Padding / effective length diagnostic logging + auto print_vec = [&](const char* label, const std::vector& v) { + if(v.empty()) + return; + std::cout << ", " << label << ":["; + for(std::size_t i = 0; i < v.size(); ++i) + { + if(i) + std::cout << ","; + std::cout << v[i]; + } + std::cout << "]"; + }; + + if(has_group_padding) + { + bool has_qpad = !seqstart_q_with_padding_host.empty(); + bool has_kpad = (seqlen_kpads[0] >= 0); + if(has_qpad) + { + print_vec("q_logical", seqlen_qs); + print_vec("q_padded", seqlen_qpads); + } + if(has_kpad) + { + print_vec("k_logical", seqlen_ks); + print_vec("k_padded", seqlen_kpads); + } + } + else if(has_batch_efflens) + { + // derive effective lengths from cumulative arrays if present + if(!cuq_cum.empty()) + { + std::vector eff_q(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_q[b_i] = static_cast(cuq_cum[b_i + 1] - cuq_cum[b_i]); + print_vec("q_eff", eff_q); + } + if(!cukv_cum.empty()) + { + std::vector eff_kv(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_kv[b_i] = static_cast(cukv_cum[b_i + 1] - cukv_cum[b_i]); + print_vec("kv_eff", eff_kv); + } + } + std::cout << std::flush; const auto init_traits = [&](auto& traits) { @@ -830,8 +960,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -846,8 +976,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); @@ -961,6 +1091,29 @@ fwd_result fmha_fwd_run(mode_enum mode, { args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } + + // Group-mode: optional physical padded starts for Q/K + if(mode == mode_enum::group) + { + args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() + ? nullptr + : seqstart_q_padded_buf.GetDeviceBuffer()); + args.seqstart_padded_k_ptr = + (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); + } + + // Batch-mode: optional cumulative effective seqlen overrides + if(mode == mode_enum::batch) + { + args.cu_seqlen_q_ptr = cuq_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_q_buf.GetDeviceBuffer()); + args.cu_seqlen_kv_ptr = cukv_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_kv_buf.GetDeviceBuffer()); + } } else if constexpr(std::is_same_v>) { @@ -1167,15 +1320,29 @@ fwd_result fmha_fwd_run(mode_enum mode, for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + if(mode == mode_enum::batch) + { + if(!cuq_cum.empty()) + { + real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; + } + if(!cukv_cum.empty()) + { + real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; + } + } // adjust matrix index according to the mode const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t cache_b_idx = (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); const ck_tile::index_t query_offset = - (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + (mode == mode_enum::batch + ? 0 + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] + : seqstart_q_with_padding_host[wb])); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 @@ -1538,8 +1705,10 @@ fwd_result fmha_fwd_run(mode_enum mode, if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + const ck_tile::index_t query_offset_lse = + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); }); cur_pass = ck_tile::check_err(lse_host_result, diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 10cb5149a4..4bd1d1a367 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -56,6 +56,11 @@ struct fmha_fwd_v3_args index_t stride_o; index_t nhead_stride_o; index_t batch_stride_o; + + // Optional batch-mode cumulative seqlen overrides (exclude PAD) + // If provided, they override per-batch effective lengths to skip tail padding. + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index e0fbad39a5..194675f962 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -158,7 +158,9 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.window_size_left, args.window_size_right, args.mask_type, - remap_opt); + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 88c16cceb6..31ad800039 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,3 +18,36 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done + +#Padding Benchmarks: batch mode (baseline vs low/med/high pad) +prec="fp16" +base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no pad) +$EXE $base_batch_args + +# low pad (≈90–95% effective) +$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + +# Padding Benchmarks: group mode (baseline vs low/med/high physical pad) +seqlens_q="1024,768,512,256" +seqlens_k="1024,768,512,256" +base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no physical pad) +$EXE $base_group_args + +# low physical pad +$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + +# medium physical pad +$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + +# high physical pad +$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index b847e85398..a3f7d68eb3 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -23,3 +23,20 @@ done done done done + +# Padding benchmark comparisons for v3 (batch mode only) +# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ==== +prec="fp16" +base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID" + +# baseline (no pad) +$EXE $base_v3_args + +# low pad (≈90–95% effective) +$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt index ea601ec002..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt @@ -1,2 +0,0 @@ -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt index ea601ec002..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt @@ -1,2 +0,0 @@ -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt index 1497d491bb..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt @@ -1,31 +0,0 @@ -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt index 90c5e2b7fb..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt +++ b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt @@ -1,4 +0,0 @@ -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index e7babd2744..5c2a5a4b3d 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -34,15 +34,15 @@ function print_log_header(){ } #run verification tests -example/ck_tile/01_fmha/script/smoke_test_fwd.sh -example/ck_tile/01_fmha/script/smoke_test_bwd.sh +time example/ck_tile/01_fmha/script/smoke_test_fwd.sh +time example/ck_tile/01_fmha/script/smoke_test_bwd.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" print_log_header $fmha_fwd_log $env_type $branch $host_name -example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log +time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log" print_log_header $fmha_bwd_log $env_type $branch $host_name -example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log +time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index 3b59505ff0..cd51dde2d4 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -6,7 +6,7 @@ SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) EXE_NAME=tile_example_fmha_bwd EXE="$(find . -name $EXE_NAME -type f | head -n 1)" KNAME=1 -GPU_arch=$GPU_arch +GPU_arch=${GPU_arch:-""} if [ -z "$GPU_arch" ] ; then GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') fi @@ -31,7 +31,17 @@ run_exe() { set -ex } +test_h_s_mask() { + run_exe -b=1 -h=4 -h_k=2 -s=259 $@ + run_exe -b=2 -h=2 -s=516 -s_k=253 $@ + run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 $@ + run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 $@ + run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 $@ + run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 $@ +} + set -x +# main tests for prec in "fp16" "bf16" ; do for perm in 0 1 ; do for hdim in 32 64 128 256 ; do @@ -40,21 +50,21 @@ for bias in "n" "a" ; do for dbias in 0 ; do for p_drop in 0.0 0.2 ; do for deterministic in 0 ; do +test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +done +done +done +done +done +done +done +done -run_exe -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS - -done -done -done -done -done -done -done +# additional cases +for hdim in 40 48 72 96 ; do +test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS done set +x diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index afd0c728c6..fca6b8d0cd 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -137,9 +137,118 @@ run_fp16_appendkv_tests() { done ; done ; done } +run_padding_smoke_tests() { + # Padding-only smoke tests for batch/group mode using COMMON_ARGS + local prec="fp16" + + # Batch mode: padding via effective lengths (exclude PAD) + # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches + local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low pad (≈90–95% effective) + $EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + # medium pad (≈60–75% effective) + $EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + # high pad (≈30–40% effective) + $EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + + # Group mode: padding via physical stride along seqlen + local seqlens_q="1024,768,512,256" + local seqlens_k="1024,768,512,256" + local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low physical pad + $EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + # medium physical pad + $EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + # high physical pad + $EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 +} + +run_padding_basic_boundary_tests() { + # Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh) + local prec + local perm + + # Group mode: Q&K padded with per-batch different strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \ + -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # slightly larger, uneven padding strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \ + -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # only K padded; Q unpadded + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \ + -s=55 -s_k=256 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # use cu_seqlen overrides to skip tail PAD + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \ + -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \ + -q_eff_lens=55,60 -kv_eff_lens=200,256 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # no padding (equal), mixed Q/KV, all len=1 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # highly variable logical lengths + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \ + -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \ + -s=256,129 -s_k=256,129 -s_kpad=256 \ + -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \ + -kname=$KNAME $COMMON_ARGS + done +} + set -x run_fp16_bf16_tests +run_padding_smoke_tests +run_padding_basic_boundary_tests run_fp8_tests run_fp8bf16_tests run_fp8fp32_tests diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 4f3b173c55..bbfb2df006 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,10 +1,12 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) +add_executable(tile_example_grouped_gemm_multi_d EXCLUDE_FROM_ALL grouped_gemm_multi_d.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) \ No newline at end of file diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index 94481fa7b7..0821065098 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -1,140 +1,8 @@ -# Grouped Gemm - -Grouped General Matrix Multiplication (Grouped GEMM) is a technique used in GPU computing and high-performance computing to batch together multiple independent GEMM operations (matrix multiplications) into a single kernel launch in order to improve performance and efficiency. This folder contains Grouped GEMM examples that use the ck_tile tile-programming implementation. - ## Quick Tour for New Users The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. Unlike regular batched GEMM operations where both matrices must be of the same size and have the same configuration, Grouped GEMM operations can take matrices with different sizes and configurations, making them more flexible for diverse workloads. -Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function. - -### Key Arguments -The example takes several arguments including `group_count`, `repeat`, and `warmup`: -- `group_count`: the number of GEMM operations in the group -- `repeat`: the number of times to repeat the kernel for benchmarking -- `warmup`: the number of iterations before the actual kernel run time measure - -```cpp -// Example -const int group_count = arg_parser.get_int("group_count"); -const int repeat = arg_parser.get_int("repeat"); -const int warmup = arg_parser.get_int("warmup"); -``` -In the next step, the input parameters `Ms`, `Ns`, `Ks`, as well as the corresponding `stride_As`, `stride_Bs`, and `stride_Cs` are either provided from the comand line or generated by default. Since one or more input data sets are expected for `A` and `B`, each parameter is stored in a `std::vector`. The size of the `vector` is defined by `group_count`. - -```cpp -// Example -std::vector Ms = arg_parser.get_int_vec("Ms"); -std::vector Ns = arg_parser.get_int_vec("Ns"); -std::vector Ks = arg_parser.get_int_vec("Ks"); -std::vector stride_As = arg_parser.get_int_vec("stride_As"); -std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); -std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); -``` -Where: -- `Ms` is the M dimension of each GEMM. -- `Ns` is the N dimension of each GEMM. -- `Ks` is the K dimension of each GEMM. -- `stride_As` is the stride values for matrix A. -- `stride_Bs` is the stride values for matrix B. -- `stride_Cs` is the stride values for matrix C. - -### HostTensor and Device Memory Buffers (for CPU and GPU) -Each parameter `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs` and `stride_Cs` contains values for more than one matrix, meaning different matrix sizes and strides can be used for different grouped GEMM computations. -The next step is to properly load the input values. For each input matrix, `A` and `B`, and for each output matrix, `C`, you need to create both `HostTensor` and `DeviceMemory`, where: -- `HostTensor` represents the matrix data on the host (CPU). It stores the data before they are transferred to the device for computation. -- `DeviceMemory` represents the matrix data on the device (GPU). This will store the data on the GPU for computation during the Grouped GEMM operation. - -#### HostTensor Buffers (for CPU) -In the first step, create `HostTensor` for `A`, `B`, `C`. `HostTensor` allocates memory on the host (CPU) to store the matrices, initializing the memory with the appropriate dimensions and values to store the data. Below is an example code showing how to create HostTensors for those tensors: -```cpp -// Example -std::vector> a_m_k_tensors; -std::vector> b_k_n_tensors; -std::vector> c_m_n_tensors; -``` -Where: -- `a_m_k_tensors` is the vector of `HostTensor` objects for matrix `A` (with dimensions `M × K`). Each tensor stores the data for single GEMM operation. -- `b_k_n_tensors` is the vector of `HostTensor` objects for matrix `B` (with dimensions `K × N`). -- `c_m_n_tensors` is the vector of `HostTensor` objects for matrix `C` (the output matrix with dimensions `M × N`). - -The `std::vector` container is used for this purpose throughout. As mentioned above, the number of HostTensors is equal to `group_count`. - -#### Device Memory Buffers (for GPU) -Now it's time to allocate memory on the device (GPU) and transfer the data from `HostTensor` to `DeviceMemory` for actual computation.. -```cpp -// Example -std::vector> a_m_k_dev_buf; -std::vector> b_k_n_dev_buf; -std::vector> c_m_n_dev_buf; -``` -Where: -- `a_m_k_dev_buf` is the buffer used for storing matrix A on the GPU. -- `b_k_n_dev_buf` is the buffer used for storing matrix B on the GPU. -- `c_m_n_dev_buf` is the buffer used for storing the result matrix C on the GPU. - -## Prepare data -In the next step, the input tensors are populated. A pseudorandom number generator, an existing distribution (e.g., `FillUniformDistribution`), or user data can be used to populate the tensors. Descriptors also need to be create for each input tensor. - -Use `get_default_stride` to get the strides for A, B, and C. `get_default_stride` is a template function that calculates the default stride for a 2D array based on whether it is row-major or column-major. Template parameter determines whether the storage order is row-major (true) or column-major (false). The function takes four params `row`, `col`, `stride` and `bool_constant`. If the stride is explicitly provided (`stride != 0`), the stride is returned as-is. If the stride is not provided (`stride == 0`), the function computes the default stride. For the Row-major order (`is_row_major == true`), the stride is set to the number of columns (col). For the column-major order (`is_row_major == false`), the stride is set to the number of rows (row). This function is useful when working with dynamically allocated 2D arrays, where the user may not specify the stride explicitly. It ensures a natural default stride based on the chosen storage order. - -```cpp -// Example, API -template -auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant) { - // code -} -``` - -Where: -- `is_row_major` is a bool template parameter that determines whether the storage order is row-major (true) or column-major (false). -- `row` is the number of rows in the matrix. -- `col` is the number of columns in the matrix. -- `stride` is the current stride (the distance between consecutive elements in memory). -- `bool_constant` is a tag type that helps in differentiating behavior at compile-time. - -Next host descriptors for each of the input tensors, A, B, and C are created. Use the `f_host_tensor_descriptor` function defined below. This function takes four parameters, row, col, stride, and layout, and returns a HostTensorDescriptor based on the specified layout. - -```cpp -// Example for tensor A -ck_tile::HostTensor(f_host_tensor_descriptor(M, K, stride_As[i], a_layout))) -``` - -After creating the host_tensors, create `deviceMem` for each tensor `A`, `B`, and `C`, and then transfer the data to the device. The `get_element_space_size_in_bytes()` function is used to get the buffer size in bytes. Use `ToDevice()` to transfer data from the host to the device. The data that was previously generated (`a_m_k_tensors[i].data()`) is passed as a parameter to `ToDevice()`. - -The final step before running the GEMM operation is to retrieve the pointers to the buffers of `A`, `B`, and `C` stored on the device using `->GetDeviceBuffer()` and pack them into a shared container. For example: `gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]})`, where `gemm_descs` is `std::vector gemm_descs` ([Code](https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc#L221)). The container should include values such as: -```cpp -struct GroupedGemmHostArgs -{ - const void* a_ptr; - const void* b_ptr; - void* c_ptr; - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; -}; -``` -The data prepared in this way can be passed to the `invoke_gemm` function. This is a templated function that also takes three template parameters: `ALayout`, `BLayout`, and `CLayout`: -```cpp -// Example, API -template -float invoke_gemm(int n_warmup, - int n_repeat, - int group_count, - const std::vector& args) -``` -`invoke_gemm` returns the run time in milliseconds. The workspace memory required for computation is allocated. Workspace memory on the GPU refers to temporary memory buffers allocated when some operations are run. This extra space is needed to hold GEMM descriptions. The following structure can be used to allocate workspace: - -```cpp -// Example -ck_tile::DeviceMem gemm_workspace; -gemm_workspace.Realloc(GetWorkspaceSize(args)); -``` - -### Advanced Features: Preshuffle and Persistence +### Preshuffle and Persistence The grouped GEMM examples include two advanced optimization features: @@ -153,17 +21,17 @@ Persistence mode is a GPU optimization where thread blocks remain active on the - **Usage**: `invoke_gemm` enables persistence - **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes -Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads. +#### Multi-D Operations +Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output. -Finally the arguments are passed to group_gemm and the kernel is launched. -```cpp -// API -template -float grouped_gemm(const std::vector& gemm_descs, - const ck_tile::stream_config& s, - void* kargs_ptr) -``` -All the necessary parameters are set, the tiling is computed, the GEMM pipeline and epilogue are prepared, and the GroupedGemmKernel is launched. +- **Implementation**: Available in `grouped_gemm_multi_d.cpp` +- **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result) +- **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors +- **Data Types**: Supports fp16 +- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call +- **Build Target**: `make tile_example_grouped_gemm_multi_d -j` + +Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads. ## Build ``` @@ -175,10 +43,13 @@ mkdir build && cd build make tile_example_grouped_gemm -j # The preshuffle example make tile_example_grouped_gemm_preshuffle -j +# The multi-D operations example +make tile_example_grouped_gemm_multi_d -j # The quant grouped gemm fp8 example make tile_example_quant_grouped_gemm -j ``` -This will result in an executable `build/bin/tile_example_grouped_gemm` +This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`. + ## example ``` @@ -213,4 +84,4 @@ K[i] = 512 + 384 * i stride_A[i] = K[i] stride_B[i] = K[i] stride_C[i] = N[i] -``` +``` \ No newline at end of file diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 6493a542ba..10d7befc06 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -9,7 +9,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/utility/json_dump.hpp" #define CK_TILE_PIPELINE_COMPUTE_V3 1 @@ -296,7 +295,7 @@ struct PipelineTypeTraits ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; }; -using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; +using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>; std::pair create_args(int argc, char* argv[]) { @@ -325,7 +324,7 @@ std::pair create_args(int argc, char* argv[]) inline std::size_t get_workspace_size(const std::vector& gemm_descs) { - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>); } template diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp new file mode 100644 index 0000000000..409eda8de4 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "grouped_gemm_multi_d.hpp" + +template +float grouped_gemm_multi_d(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + return ave_time; +} + +#include "run_grouped_gemm_multi_d_example.inc" + +int main(int argc, char* argv[]) +{ +#if CK_TILE_USE_WMMA + return !run_grouped_gemm_multi_d_example(argc, argv); +#else + return !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv); +#endif +} diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp new file mode 100644 index 0000000000..f7727d854c --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/utility/json_dump.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +using ADataType = ck_tile::half_t; +using BDataType = ck_tile::half_t; +using D0DataType = ck_tile::half_t; +using D1DataType = ck_tile::half_t; +using EDataType = ck_tile::half_t; +using DsDataType = ck_tile::tuple; +using AccDataType = float; + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet + static constexpr bool Persistent = false; // currently persistent == true is not supported yet + static constexpr bool DoubleSmemBuffer = + false; // currently double smem buffer == true is not supported yet +}; + +struct GemmConfigMemory : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 8; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +struct GemmConfigV3 : public GemmConfigBase +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; +struct GemmConfigV4 : public GemmConfigBase +{ + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigV3_Wmma : public GemmConfigBase +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<2>; + +std::pair create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("Ms", "", "M dimensions - empty by default.") + .insert("Ns", "", "N dimensions - empty by default.") + .insert("Ks", "", "K dimensions - empty by default.") + .insert("stride_As", "", "Tensor A strides - it is empty by default.") + .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") + .insert("stride_Ds", "", "Tensor Ds strides - it is empty by default.") + .insert("stride_Es", "", "Tensor E strides - it is empty by default.") + .insert("a_layout", "R", "A tensor data layout - Row by default.") + .insert("b_layout", "C", "B tensor data layout - Row by default.") + .insert("ds_layout", "R", "Ds tensor data layout - Row by default.") + .insert("e_layout", "R", "E tensor data layout - Row by default.") + .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("prec", "fp16", "data type. fp16") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "grouped_gemm.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_pair(result, arg_parser); +} + +inline std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<2>); +} + +template +float grouped_gemm_multi_d(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 17e0ee5342..10d317a2c7 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -183,12 +183,24 @@ int run_grouped_gemm_example_with_layouts(int argc, if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) { std::cout << "Please check the input data. Default values will be used." << std::endl; + + // Clear existing (invalid) data before adding defaults + Ms.clear(); + Ns.clear(); + Ks.clear(); + stride_As.clear(); + stride_Bs.clear(); + stride_Cs.clear(); + stride_AQs.clear(); + stride_BQs.clear(); + for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); Ns.push_back(256 + 512 * i); Ks.push_back(512 + 128 * i); + // Let get_default_stride calculate based on layout stride_As.push_back(0); stride_Bs.push_back(0); stride_Cs.push_back(0); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 1cd2212994..f822c7d8a7 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -88,7 +88,7 @@ float invoke_gemm(int n_warmup, // The contents of the memory pointed to by `kargs_ptr` pointer could be // written by e.g. another kernel from earlier stage. - std::vector kargs; + std::vector> kargs; void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); const bool splitk = args[0].k_batch > 1; for(const auto& arg : args) @@ -109,7 +109,7 @@ float invoke_gemm(int n_warmup, const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, kargs.data(), - kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>), hipMemcpyHostToDevice, stream.stream_id_)); ave_time = grouped_gemm_tileloopGetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); - gemm_descs.push_back( - {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + gemm_descs.push_back({p_a, + p_b, + {/*ds_ptr*/}, + p_c, + kbatch, + M, + N, + K, + stride_As[i], + stride_Bs[i], + {/*stride_Ds*/}, + stride_Cs[i]}); } float ave_time = invoke_gemm + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_gemm(int n_warmup, + int n_repeat, + int group_count, + const std::vector& args) +{ + // Workspace memory allocated to hold the gemm descriptions. + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(args)); + + float ave_time = 0; + if constexpr(!GemmConfig::Persistent) + { + ave_time = grouped_gemm_multi_d( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + (void)group_count; + // not supported yet + throw std::runtime_error("Persistent grouped gemm multiple-d is not supported yet"); + } + return ave_time; +} + +template +int run_grouped_gemm_multi_d_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + const D0Layout d0_layout = D0Layout{}, + const D1Layout d1_layout = D1Layout{}, + const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + + using CDElementWise = MultiplyMultiply; + using DsLayout = ck_tile::tuple; + + auto valid_input_data = [&](int group_count, const auto&... args) { + return !(args.empty() || ...) && group_count == (args.size() == ...); + }; + + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + bool validate = arg_parser.get_bool("validate"); + + if(kbatch > 1 && validate && warmup + repeat > 1) + { + std::cout << "WARNING: Data validation enabled with SplitK and more than" + << "1 warmup/repeat. Disabling validation." << std::endl; + validate = false; + } + + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); + std::vector stride_D0 = arg_parser.get_int_vec("stride_Ds"); + std::vector stride_D1 = arg_parser.get_int_vec("stride_Ds"); + std::vector stride_Es = arg_parser.get_int_vec("stride_Es"); + + if(!valid_input_data( + group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_D0, stride_D1, stride_Es)) + { + std::cout << "Please check the input data. Default values will be used." << std::endl; + std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, " + "896, 1280..), stride_As (Ks), stride_Bs (Ks), stride_D0 (Ns), stride_D1 " + "(Ns), stride_Es (Ns)" + << std::endl; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 /* + 256 * i */); + Ns.push_back(256 /* + 512 * i */); + Ks.push_back(64 /* + 384 * i */); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_D0.push_back(Ns[i]); + stride_D1.push_back(Ns[i]); + stride_Es.push_back(Ns[i]); + } + } + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> d0_m_n_tensors; + std::vector> d1_m_n_tensors; + std::vector> e_m_n_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + d0_m_n_tensors.reserve(group_count); + d1_m_n_tensors.reserve(group_count); + e_m_n_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> d0_m_n_dev_buf; + std::vector> d1_m_n_dev_buf; + std::vector> e_m_n_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + d0_m_n_dev_buf.reserve(group_count); + d1_m_n_dev_buf.reserve(group_count); + e_m_n_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + + stride_D0[i] = ck_tile::get_default_stride(M, N, stride_D0[i], is_row_major(d0_layout)); + stride_D1[i] = ck_tile::get_default_stride(M, N, stride_D1[i], is_row_major(d1_layout)); + + stride_Es[i] = ck_tile::get_default_stride(M, N, stride_Es[i], is_row_major(e_layout)); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); + + d0_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_D0[i], is_row_major(d0_layout)))); + d1_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_D1[i], is_row_major(d1_layout)))); + + e_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Es[i], is_row_major(e_layout)))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " d0_m_n: " << d0_m_n_tensors[i].mDesc + << " d1_m_n: " << d1_m_n_tensors[i].mDesc << " e_m_n: " << e_m_n_tensors[i].mDesc + << std::endl; + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{2.f, -2.f}(d0_m_n_tensors[i]); + ck_tile::FillUniformDistribution{2.f, -2.f}(d1_m_n_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique(a_m_k_tensors[i])); + + b_k_n_dev_buf.push_back(std::make_unique(b_k_n_tensors[i])); + + d0_m_n_dev_buf.push_back(std::make_unique(d0_m_n_tensors[i])); + d1_m_n_dev_buf.push_back(std::make_unique(d1_m_n_tensors[i])); + e_m_n_dev_buf.push_back(std::make_unique(e_m_n_tensors[i])); + + e_m_n_dev_buf[i]->SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_e = e_m_n_dev_buf[i]->GetDeviceBuffer(); + + std::array ds_ptr_buf = { + d0_m_n_dev_buf[i]->GetDeviceBuffer(), d1_m_n_dev_buf[i]->GetDeviceBuffer()}; + std::array stridesDs = {stride_D0[i], stride_D1[i]}; + + gemm_descs.push_back({p_a, + p_b, + ds_ptr_buf, + p_e, + kbatch, + M, + N, + K, + stride_As[i], + stride_Bs[i], + stridesDs, + stride_Es[i]}); + } + + float ave_time = invoke_gemm(warmup, repeat, group_count, gemm_descs); + + std::string op_name{"Grouped Gemm Multiple-D"}; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K; + ck_tile::static_for<0, DsDataType::size(), 1>{}([&](auto i) { + num_btype += sizeof(ck_tile::remove_cvref_t>) * + gemm_descs[j].M * gemm_descs[j].N; + flop += sizeof(ck_tile::remove_cvref_t>) * + gemm_descs[j].M * gemm_descs[j].N; + }); + + num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K + + sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N + + sizeof(EDataType) * gemm_descs[j].M * gemm_descs[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + std::vector> e_m_n_host_refs; + e_m_n_host_refs.reserve(group_count); + + // copy e_m_n_tensors result from device to host and initialize host tensors to zero + for(int i = 0; i < group_count; i++) + { + e_m_n_dev_buf[i]->FromDevice(e_m_n_tensors[i].data()); + } + + bool pass{true}; + if(validate) + { + for(int i = 0; i < group_count; ++i) + { + e_m_n_host_refs.push_back(ck_tile::HostTensor( + host_tensor_descriptor(Ms[i], Ns[i], stride_Es[i], is_row_major(e_layout)))); + + e_m_n_host_refs[i].SetZero(); + + ck_tile::reference_gemm_multiple_d( + a_m_k_tensors[i], + b_k_n_tensors[i], + {d0_m_n_tensors[i], d1_m_n_tensors[i]}, + e_m_n_host_refs[i]); + std::cout << "e_m_n_host_refs[i]: " << std::endl; + e_m_n_host_refs[i].print_first_n(std::cout, 10); + std::cout << std::endl; + std::cout << "e_m_n_tensors[i]: " << std::endl; + e_m_n_tensors[i].print_first_n(std::cout, 10); + std::cout << std::endl; + + const float max_accumulated_value = + *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); + + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value); + + pass &= + ck_tile::check_err(e_m_n_tensors[i], + e_m_n_host_refs[i], + "Error: Incorrect results! in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + if(arg_parser.get_int("json") == 1) + { + dump_grouped_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + group_count, + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass; +} + +template +int run_grouped_gemm_multi_d_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string ds_layout = arg_parser.get_str("ds_layout"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(a_layout == "R" && b_layout == "C" && ds_layout == "R") + { + return run_grouped_gemm_multi_d_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for provided tensors!"); + } +} diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp index 4f9362beb2..fa914a7119 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp @@ -11,199 +11,14 @@ #include "ck_tile/host.hpp" #include "grouped_convolution_utils.hpp" - -template , - typename DsLayout = ck_tile::tuple<>, - typename CDEElementWise = ck_tile::element_wise::PassThrough> -float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, - const ck_tile::stream_config& s) -{ - constexpr int kBlockPerCu = 1; - - constexpr ck_tile::index_t M_Tile = 64; - constexpr ck_tile::index_t N_Tile = 64; - constexpr ck_tile::index_t K_Tile = 32; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; - constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; - constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; - - constexpr ck_tile::index_t VectorSizeA = 1; - constexpr ck_tile::index_t VectorSizeB = 1; - constexpr ck_tile::index_t VectorSizeC = 8; - - // Implicit GEMM Traits - using CodegenShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using GroupedConvTraitsType = ck_tile::GroupedConvTraits; - using CodegenPipelineProblem = ck_tile::GemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - CodegenShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - InDataType, - true, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; - using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - - using ConvEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << '\n' - << "Vector size A: " << CodegenPipeline::GetVectorSizeA() - << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - - if(args.k_batch == 1) - { - return Run(ck_tile::integral_constant{}); - } - else - { - return Run(ck_tile::integral_constant{}); - } -} - +#include "grouped_convolution_backward_data_invoker.hpp" #include "run_grouped_convolution_bwd_data_example.inc" -template -int run_grouped_conv_bwd_data_example_prec_type( - std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) -{ - using NWGC = ck_tile::tensor_layout::convolution::NWGC; - using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; - using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; - - using GKXC = ck_tile::tensor_layout::convolution::GKXC; - using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; - using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; - - using NWGK = ck_tile::tensor_layout::convolution::NWGK; - using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; - using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; - - if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") - { - return run_grouped_conv_bwd_data_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NWGC{}, GKXC{}, NWGK{}); - } - else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") - { - return run_grouped_conv_bwd_data_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); - } - else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") - { - return run_grouped_conv_bwd_data_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); - } - else - { - throw std::runtime_error("Unsupported memory layout!"); - } -} - template int run_grouped_conv_bwd_data_example(int argc, char* argv[]) { + using Invoker = GroupedConvolutionBackwardDataInvoker; + auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; @@ -215,12 +30,16 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[]) if(data_type == "fp16") { - return run_grouped_conv_bwd_data_example_prec_type( + return run_grouped_conv_bwd_data_example_prec_type( in_layout, wei_layout, out_layout, argc, argv); } else if(data_type == "bf16") { - return run_grouped_conv_bwd_data_example_prec_type( + return run_grouped_conv_bwd_data_example_prec_type( in_layout, wei_layout, out_layout, argc, argv); } else diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp new file mode 100644 index 0000000000..1b3d45427d --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "grouped_convolution_utils.hpp" + +struct GroupedConvolutionBackwardDataInvoker +{ + + template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> + static float grouped_conv_bwd_data(const ck_tile::GroupedConvBwdDataHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; + + constexpr ck_tile::index_t VectorSizeA = 1; + constexpr ck_tile::index_t VectorSizeB = 1; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + CodegenShape, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsBwdData, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + true, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } + } +}; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index cebfa90579..4cddbae3ab 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -11,190 +11,14 @@ #include "ck_tile/host.hpp" #include "grouped_convolution_utils.hpp" - -template , - typename DsLayout = ck_tile::tuple<>, - typename CDEElementWise = ck_tile::element_wise::PassThrough> -float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s) -{ - constexpr int kBlockPerCu = 1; - - constexpr ck_tile::index_t M_Tile = 64; - constexpr ck_tile::index_t N_Tile = 64; - constexpr ck_tile::index_t K_Tile = 64; - - constexpr ck_tile::index_t M_Warp = 2; - constexpr ck_tile::index_t N_Warp = 2; - constexpr ck_tile::index_t K_Warp = 1; - - constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; - constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; - constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; - - constexpr ck_tile::index_t VectorSizeA = 8; - constexpr ck_tile::index_t VectorSizeB = 8; - constexpr ck_tile::index_t VectorSizeC = 8; - - // Implicit GEMM Traits - using CodegenShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using GroupedConvTraitsType = ck_tile::GroupedConvTraits; - using CodegenPipelineProblem = ck_tile::GemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - CodegenShape, - typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - InDataType, - true, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; - using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - - using ConvEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << '\n' - << "Vector size A: " << CodegenPipeline::GetVectorSizeA() - << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - - return Run(ck_tile::integral_constant{}); -} - +#include "grouped_convolution_forward_invoker.hpp" #include "run_grouped_convolution_fwd_example.inc" -template -int run_grouped_conv_fwd_example_prec_type( - std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) -{ - using NWGC = ck_tile::tensor_layout::convolution::NWGC; - using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; - using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; - - using GKXC = ck_tile::tensor_layout::convolution::GKXC; - using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; - using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; - - using NWGK = ck_tile::tensor_layout::convolution::NWGK; - using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; - using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; - - if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") - { - return run_grouped_conv_fwd_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NWGC{}, GKXC{}, NWGK{}); - } - else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") - { - return run_grouped_conv_fwd_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); - } - else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC") - { - return run_grouped_conv_fwd_example_with_layouts{}, - GemmWarpConfig, - InPrecType, - WeiPrecType, - OutPrecType>( - argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); - } - else - { - throw std::runtime_error("Unsupported memory layout!"); - } -} - template int run_grouped_conv_fwd_example(int argc, char* argv[]) { + using Invoker = GroupedConvolutionForwardInvoker; + auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; @@ -206,12 +30,12 @@ int run_grouped_conv_fwd_example(int argc, char* argv[]) if(data_type == "fp16") { - return run_grouped_conv_fwd_example_prec_type( + return run_grouped_conv_fwd_example_prec_type( in_layout, wei_layout, out_layout, argc, argv); } else if(data_type == "bf16") { - return run_grouped_conv_fwd_example_prec_type( + return run_grouped_conv_fwd_example_prec_type( in_layout, wei_layout, out_layout, argc, argv); } else diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp new file mode 100644 index 0000000000..0b9879d247 --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "grouped_convolution_utils.hpp" + +struct GroupedConvolutionForwardInvoker +{ + template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> + static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = GemmWarpConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmWarpConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmWarpConfig::K_Warp_Tile; + + constexpr ck_tile::index_t VectorSizeA = 8; + constexpr ck_tile::index_t VectorSizeB = 8; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + CodegenShape, + typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + true, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue>; + + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + return Run(ck_tile::integral_constant{}); + } +}; diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc index 8519daaac2..3d7635bf4f 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_data_example.inc @@ -4,6 +4,7 @@ template ( + float ave_time = Invoker::template grouped_conv_bwd_data( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = args.GetFlops(); @@ -39,6 +40,7 @@ float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args, template +int run_grouped_conv_bwd_data_example_prec_type( + std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) +{ + using NWGC = ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") + { + return run_grouped_conv_bwd_data_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc index c5ae92a0da..beb6005e19 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc @@ -4,6 +4,7 @@ template ( + float ave_time = Invoker::template grouped_conv_fwd( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = args.GetFlops(); @@ -39,6 +40,7 @@ float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, template +int run_grouped_conv_fwd_example_prec_type( + std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) +{ + using NWGC = ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return run_grouped_conv_fwd_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") + { + return run_grouped_conv_fwd_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") + { + return run_grouped_conv_fwd_example_with_layouts{}, + GemmWarpConfig, + Invoker, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} diff --git a/example/ck_tile/21_elementwise/elementwise_example.cpp b/example/ck_tile/21_elementwise/elementwise_example.cpp index 94d3e70be1..e9fbeafde1 100644 --- a/example/ck_tile/21_elementwise/elementwise_example.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example.cpp @@ -211,7 +211,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp index ff7ec1517e..1b101c2e5f 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp @@ -157,7 +157,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp index 16e9832c07..7cdb5cc0d1 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp @@ -156,7 +156,9 @@ bool run(const ck_tile::ArgParser& arg_parser) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp index c5a08d910e..4e19cfd688 100644 --- a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp +++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp @@ -193,7 +193,9 @@ auto string_to_op(const std::string& op) int main(int argc, char* argv[]) { - auto [result, arg_parser] = create_args(argc, argv); + bool result = true; + ck_tile::ArgParser arg_parser; + std::tie(result, arg_parser) = create_args(argc, argv); if(!result) return -1; diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 5783605f8d..7aee7fca28 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/config.h" +#include #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 08b3aba2b3..5da447125e 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,6 +15,151 @@ namespace ck { namespace utility { +template +struct RotatingMemWrapperMultiABD +{ + static constexpr index_t NumAs = AsDataType::Size(); + static constexpr index_t NumBs = BsDataType::Size(); + static constexpr index_t NumDs = DsDataType::Size(); + + using AsGridPointer = decltype(Argument::p_as_grid); + using BsGridPointer = decltype(Argument::p_bs_grid); + using DsGridPointer = decltype(Argument::p_ds_grid); + + RotatingMemWrapperMultiABD() = delete; + RotatingMemWrapperMultiABD(Argument& arg_, + std::size_t rotating_count_, + std::array size_as_, + std::array size_bs_, + std::array size_ds_) + : arg(arg_), + rotating_count(rotating_count_), + size_as(size_as_), + size_bs(size_bs_), + size_ds(size_ds_) + { + p_as_grids.push_back(arg.p_as_grid); + p_bs_grids.push_back(arg.p_bs_grid); + p_ds_grids.push_back(arg.p_ds_grid); + for(size_t i = 1; i < rotating_count; i++) + { + { + AsGridPointer as_buffer; + static_for<0, NumAs, 1>{}([&](auto j) { + void* pADeviceBuf; + hip_check_error(hipMalloc(static_cast(&pADeviceBuf), size_as_[j])); + hip_check_error(hipMemcpy(static_cast(pADeviceBuf), + static_cast(p_as_grids[0][j]), + size_as_[j], + hipMemcpyDeviceToDevice)); + using ADataType = remove_cvref_t>; + + as_buffer(j) = static_cast(pADeviceBuf); + }); + p_as_grids.push_back(as_buffer); + } + + { + BsGridPointer bs_buffer; + static_for<0, NumBs, 1>{}([&](auto j) { + void* pBDeviceBuf; + hip_check_error(hipMalloc(static_cast(&pBDeviceBuf), size_bs_[j])); + hip_check_error(hipMemcpy(static_cast(pBDeviceBuf), + static_cast(p_bs_grids[0][j]), + size_bs_[j], + hipMemcpyDeviceToDevice)); + using BDataType = remove_cvref_t>; + + bs_buffer(j) = static_cast(pBDeviceBuf); + }); + p_bs_grids.push_back(bs_buffer); + } + + { + DsGridPointer ds_buffer; + static_for<0, NumDs, 1>{}([&](auto j) { + void* pDDeviceBuf; + hip_check_error(hipMalloc(static_cast(&pDDeviceBuf), size_ds_[j])); + hip_check_error(hipMemcpy(static_cast(pDDeviceBuf), + static_cast(p_ds_grids[0][j]), + size_ds_[j], + hipMemcpyDeviceToDevice)); + + using DDataType = remove_cvref_t>; + + ds_buffer(j) = static_cast(pDDeviceBuf); + }); + + p_ds_grids.push_back(ds_buffer); + } + } + } + + void Next() + { + if(rotating_count > 1) + { + std::size_t idx = iter++ % rotating_count; + arg.p_as_grid = p_as_grids[idx]; + arg.p_bs_grid = p_bs_grids[idx]; + arg.p_ds_grid = p_ds_grids[idx]; + } + } + void Print() + { + std::cout << "RotatingMemWrapperMultiD: { size_a: {"; + static_for<0, NumAs, 1>{}( + [&](auto j) { std::cout << size_as[j] << (j.value < NumAs - 1 ? ", " : ""); }); + std::cout << "}, size_b: {"; + static_for<0, NumBs, 1>{}( + [&](auto j) { std::cout << size_bs[j] << (j.value < NumBs - 1 ? ", " : ""); }); + std::cout << "}, rotating_count: " << rotating_count << "}" << std::endl; + } + ~RotatingMemWrapperMultiABD() + { + if(rotating_count > 1) + { + // restore ptr + arg.p_as_grid = p_as_grids[0]; + arg.p_bs_grid = p_bs_grids[0]; + arg.p_ds_grid = p_ds_grids[0]; + + // free device mem + for(size_t i = 1; i < rotating_count; i++) + { + static_for<0, NumAs, 1>{}([&](auto j) { + using ADataType = remove_cvref_t>; + hip_check_error( + hipFree(static_cast(const_cast(p_as_grids[i][j])))); + }); + + static_for<0, NumBs, 1>{}([&](auto j) { + using BDataType = remove_cvref_t>; + hip_check_error( + hipFree(static_cast(const_cast(p_bs_grids[i][j])))); + }); + + static_for<0, NumDs, 1>{}([&](auto j) { + using DDataType = remove_cvref_t>; + hip_check_error( + hipFree(static_cast(const_cast(p_ds_grids[i][j])))); + }); + } + } + } + + private: + Argument& arg; + std::size_t iter = 0; + std::size_t rotating_count = 1; + std::array size_as = {0}; + std::array size_bs = {0}; + std::array size_ds = {0}; + std::vector p_as_grids; + std::vector p_bs_grids; + std::vector p_ds_grids; +}; + template struct RotatingMemWrapperMultiD { @@ -318,6 +463,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, // total_time += cur_time; // #endif +#if !defined(CK_USE_WMMA) if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; @@ -326,6 +472,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, static_cast(gemm_args.p_a_grid), static_cast(gemm_args.p_b_grid)); } +#endif } hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); hip_check_error(hipEventSynchronize(stop)); diff --git a/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp b/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp index d4ceefb458..e8d33f4216 100644 --- a/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp +++ b/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp @@ -203,8 +203,11 @@ make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvPa } return transpose_host_tensor_descriptor_given_new2old( - HostTensorDescriptor(physical_lengths), - detail::get_layout_transpose_gnchw_to_old()); + // TBD: specify explicit conv layout rather than base one + HostTensorDescriptor(physical_lengths, + ck::tensor_layout::convolution::BaseConvolutionLayout{}), + detail::get_layout_transpose_gnchw_to_old(), + InLayout{}); } // make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX @@ -296,8 +299,10 @@ make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck::utils::conv::ConvPa } return transpose_host_tensor_descriptor_given_new2old( - HostTensorDescriptor(physical_lengths), - detail::get_layout_transpose_gnchw_to_old()); + HostTensorDescriptor(physical_lengths, + ck::tensor_layout::convolution::BaseConvolutionLayout{}), + detail::get_layout_transpose_gnchw_to_old(), + WeiLayout{}); } // make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW @@ -386,8 +391,10 @@ make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvP } return transpose_host_tensor_descriptor_given_new2old( - HostTensorDescriptor(physical_lengths), - detail::get_layout_transpose_gnchw_to_old()); + HostTensorDescriptor(physical_lengths, + ck::tensor_layout::convolution::BaseConvolutionLayout{}), + detail::get_layout_transpose_gnchw_to_old(), + OutLayout{}); } } // namespace conv diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index fb8f6e79dc..55505524e0 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -21,6 +21,8 @@ #include "ck/library/utility/ranges.hpp" #include "ck/library/utility/thread.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + template std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) { @@ -97,59 +99,455 @@ auto construct_f_unpack_args(F, T args) return construct_f_unpack_args_impl(args, std::make_index_sequence{}); } +/** + * @brief A descriptor class for host tensors that manages tensor dimensions, strides, and layout. + * + * The HostTensorDescriptor provides a comprehensive interface for describing multi-dimensional + * tensors with configurable layouts and automatic stride calculation capabilities. + * + * @section stride_handling Stride Handling + * + * The descriptor supports multiple stride specification modes: + * + * 1. **Explicit Strides**: When strides are provided explicitly, they are validated against + * the specified layout to ensure memory access patterns are correct. + * + * 2. **Auto-calculated Strides**: When strides are empty or all-zero, they are automatically + * calculated based on the tensor layout: + * - For RowMajor layout: rightmost dimension has stride 1, others calculated as cumulative + * products + * - For ColumnMajor layout: similar to RowMajor but with swapped stride positions for last two + * dimensions + * + * 3. **Partial Stride Specification**: For GEMM layouts, unknown strides (represented as 0 or + * negative values) in the last two dimensions can be auto-calculated while preserving higher + * dimension strides. + * + * 4. **Bypass**: When using `BypassLayoutVerification` layout, no stride calculation or validation + * is performed. That allows to pass in any arbitrary strides including 0. + * + * For more details see `CalculateStrides` method. + * + * @section layout_support Layout Support + * + * - **GEMM Layouts**: Supports RowMajor and ColumnMajor layouts with full validation + * - **Convolution Layouts**: Recognized but validation is not yet implemented + * - **Abstract Layouts**: BaseTensorLayout will attempt automatic layout detection for 2D tensors + * + * @section limitations Limitations + * + * 1. **Layout Detection**: Automatic layout detection only works reliably for 2D tensors. + * This is done mostly for legacy GEMM cases to avoid modifying many existing GEMM tests to pass + * RowMajor/ColumnMajor explicitly. Higher-dimensional tensors with BaseTensorLayout will throw + * validation errors. For more details see `HandleDefaultLayout` method. + * + * 2. **Stride Validation**: Only GEMM layouts (RowMajor/ColumnMajor) have full stride validation. + * Convolution layouts are accepted but not validated. For more details see `ValidateStrides`. + * + * 3. **GEMM Assumptions**: For tensors with more than 2 dimensions, GEMM layout validation + * assumes the last two dimensions represent the height-width pattern (e.g., BHW or BWH for + * batched GEMM). + * + * 4. **Negative Stride Handling**: Negative stride values are interpreted as "unknown" and + * converted to auto-calculated values only for supported layouts. + * + * @section thread_safety Thread Safety + * This class is not thread-safe. External synchronization is required for concurrent access. + * + * @section examples Usage Examples + * + * ```cpp + * // Auto-calculate strides for RowMajor layout + * HostTensorDescriptor desc1({4, 3}, ck::tensor_layout::gemm::RowMajor{}); + * + * // Explicit strides with validation + * HostTensorDescriptor desc2({4, 3}, {3, 1}, ck::tensor_layout::gemm::RowMajor{}); + * + * // Partial stride specification (auto-calculate unknown dimension) + * HostTensorDescriptor desc3({4, 3}, {0, 1}, ck::tensor_layout::gemm::RowMajor{}); + * ``` + */ struct HostTensorDescriptor { - HostTensorDescriptor() = default; + using BaseTensorLayout = ck::tensor_layout::BaseTensorLayout; + using DefaultLayout = BaseTensorLayout; - void CalculateStrides(); - - template >> - HostTensorDescriptor(const std::initializer_list& lens) : mLens(lens.begin(), lens.end()) + // Runtime tag describing which layout is picked when layout is not specified explicitly at + // construction time. + enum class ChosenLayout { - this->CalculateStrides(); + Original, + RowMajor, + ColumnMajor + }; + + // Master constructor + template + HostTensorDescriptor(std::vector lens, + std::vector strides, + const Layout& layout = DefaultLayout()) + : mLens(std::move(lens)), mStrides(std::move(strides)) + { + // To support legacy use cases, when layout is not passed in + const auto new_layout = HandleDefaultLayout(layout); + if(dbg) + { + std::cout << "Original Lens: ["; + LogRange(std::cout, mLens, ", ") << "] and Strides: ["; + LogRange(std::cout, mStrides, ", ") << "]" << std::endl; + std::cout << "Layout: " << layout << " --> " << new_layout << std::endl; + } + + // Handling the strides and validation based on the chosen layout + DispatchChosenLayout(new_layout, layout, [&](auto selected_layout) { + this->CalculateStrides(selected_layout); + this->ValidateStrides(selected_layout); + }); } - HostTensorDescriptor(const std::initializer_list& lens) - : mLens(lens.begin(), lens.end()) + HostTensorDescriptor() : HostTensorDescriptor({}, {}, DefaultLayout()){}; + + // Helper that invokes a callable with a concrete layout object whose type + // matches the chosen tag (so template code depending on the layout type + // can still leverage if constexpr branches). + template + void DispatchChosenLayout(ChosenLayout tag, const OrigLayout& orig, F&& f) const { - this->CalculateStrides(); + switch(tag) + { + case ChosenLayout::RowMajor: f(ck::tensor_layout::gemm::RowMajor{}); break; + case ChosenLayout::ColumnMajor: f(ck::tensor_layout::gemm::ColumnMajor{}); break; + case ChosenLayout::Original: + default: f(orig); break; + } + } + + template + ChosenLayout HandleDefaultLayout(const Layout&) + { + if constexpr(!std::is_same_v) + { + return ChosenLayout::Original; + } + else + { + if(mStrides.empty()) + { + // No strides provided -> assume RowMajor + return ChosenLayout::RowMajor; + } + + const auto rank = mLens.size(); + + if(rank > 2) + { + // Keep as-is - validation will warn/throw later + return ChosenLayout::Original; + } + + if(rank == 0) + { + // Keep as-is - validation will warn/throw later + return ChosenLayout::Original; + } + + if(rank == 1) + { + // Treat 1D tensor as RowMajor + return ChosenLayout::RowMajor; + } + + // rank == 2 + if(mStrides.size() == 2) + { + // RowMajor pattern (?, 1) + if(mStrides[1] == 1) + { + return ChosenLayout::RowMajor; + } + + // ColumnMajor pattern (1, ?) + if(mStrides[0] == 1) + { + return ChosenLayout::ColumnMajor; + } + } + + // Fallback: leave as-is + return ChosenLayout::Original; + } + } + + template + void CalculateStrides(const Layout& layout) + { + if constexpr(std::is_same_v) + return; + // This is a workaround if the original stride value is -1 (which means "unknown") has been + // passed in and casted to size_t (unsigned). + auto strides_int = AsInt(mStrides); + + // case of empty strides or all-zero: auto-calculate based on layout and tensor dimensions + if(mStrides.empty() || std::all_of(strides_int.begin(), strides_int.end(), [](int stride) { + return stride <= 0; + })) + { + + if constexpr(!(std::is_same_v || + std::is_same_v)) + { + std::cerr << "Only RowMajor and ColumnMajor layouts are supported for empty " + "strides, got " + << layout << ". Will calculate strides as RowMajor." << std::endl; + } + + mStrides.clear(); + mStrides.resize(mLens.size(), 0); + if(mStrides.empty()) + return; + + mStrides.back() = 1; + std::partial_sum(mLens.rbegin(), + mLens.rend() - 1, + mStrides.rbegin() + 1, + std::multiplies()); + + if constexpr(std::is_same_v) + { + // swap the last two strides + if(mStrides.size() >= 2) + std::swap(mStrides[mStrides.size() - 1], mStrides[mStrides.size() - 2]); + } + } + // The other case is if one of the strides is unknown + // Currently, only GEMM RowMajor and ColumnMajor layouts are supported and only in the lower + // two dimensions, e.g. {..., 0, N} or {..., M, 0}. The higher dimensions are left + // untouched. + else if constexpr(std::is_same_v || + std::is_same_v) + { + auto rank = mStrides.size(); + if(mLens.size() >= 2 && rank >= 2) + { + const auto inner_idx = + std::is_same_v ? rank - 1 : rank - 2; + const auto outer_idx = inner_idx == rank - 1 ? rank - 2 : rank - 1; + if(mStrides[inner_idx] <= 0) + { + mStrides[inner_idx] = 1; + } + if(mStrides[outer_idx] <= 0) + { + mStrides[outer_idx] = mLens[inner_idx] * mStrides[inner_idx]; + } + } + } + } + + template + void ValidateStrides(const Layout& layout) const + { + if constexpr(std::is_same_v) + { + return; + } + + if(mLens.empty()) + { + throw std::runtime_error( + "HostTensorDescriptor::ValidateStrides: empty tensor dimensions is not allowed."); + } + + const int rank = mLens.size(); + if(rank == 1) // skip any 1D tensors + { + return; + } + + if constexpr(std::is_same_v) + { + // Any legacy code that doesn't pass layout to HostTensorDescriptor ctor will + // hit this case (unless it is a special case - see `HandleDefaultLayout`). + throw std::runtime_error("HostTensorDescriptor::ValidateStrides: Abstract tensor " + "layout BaseTensorLayout can't be verified. Pls " + "pass specific tensor layout to HostTensorDescriptor (or " + "ck::tensor_layout::BypassLayoutVerification)"); + } + + // GEMM cases + if constexpr(std::is_base_of_v) + { + if(mLens.size() != mStrides.size()) + { + std::ostringstream oss; + oss << "HostTensorDescriptor::ValidateStrides: mismatch between tensor rank and " + "size of strides: " + << *this; + throw std::runtime_error(oss.str()); + } + + // in GEMM, strides must be all positive or all zeros (auto-derived from tensor + // dimensions) + auto strides_int = AsInt(mStrides); + if(std::any_of( + strides_int.begin(), strides_int.end(), [](int stride) { return stride <= 0; })) + { + std::ostringstream oss; + oss << "Stride values must be positive or all-zeros (auto-derived from tensor " + "dimensions). Instead got "; + std::copy( + strides_int.begin(), strides_int.end(), std::ostream_iterator(oss, " ")); + throw std::runtime_error(oss.str()); + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + // The logic here assumes the GEMM with tensor of more than 2 dims, will always have + // HW dimesnsions as the inner ones e.g. batched GEMM is either BHW or BWH + const auto inner_idx = + std::is_same_v ? rank - 1 : rank - 2; + const auto outer_idx = inner_idx == rank - 1 ? rank - 2 : rank - 1; + + if(mStrides[outer_idx] < mLens[inner_idx] * mStrides[inner_idx]) + { + std::ostringstream oss; + oss << "Invalid strides for " << layout << ": " << *this; + throw std::runtime_error(oss.str()); + } + + // For higher dimensions, validate strides assuming RowMajor + for(int i = 1; i < rank - 2; ++i) + { + if(mStrides[i - 1] < mStrides[i] * mLens[i]) + { + std::ostringstream oss; + oss << "Invalid strides for higher dimensions in " << layout << ": " + << *this; + throw std::runtime_error(oss.str()); + } + } + } + else + { + std::ostringstream oss; + oss << "Error: Unsupported GEMM layout: " << layout; + throw std::runtime_error(oss.str()); + } + } + // Convolution cases + else if constexpr(std::is_base_of_v) + { + // TBD: implement verification for Conv layouts + // For now, just print warning and return + std::cerr << "Warning: Tensor layout verification for ck::tensor_layout::convolution " + "layouts is not supported yet. Skipping..." + << std::endl; + return; + } + else + { + std::ostringstream oss; + oss << "Error: Tensor layout verification for " << layout << " is not supported yet."; + throw std::runtime_error(oss.str()); + } + } + + template && + std::is_convertible_v>> + HostTensorDescriptor(const std::initializer_list& lens, const Layout& layout = Layout{}) + : HostTensorDescriptor(std::vector(lens.begin(), lens.end()), {}, layout) + { + if(dbg) + std::cout << "HostTensorDescriptor ctor (" << __LINE__ << ")" << std::endl; + } + + template >> + HostTensorDescriptor(const std::initializer_list& lens, + const Layout& layout = Layout{}) + : HostTensorDescriptor(std::vector(lens.begin(), lens.end()), {}, layout) + { + if(dbg) + std::cout << "HostTensorDescriptor ctor (" << __LINE__ << ")" << std::endl; } template , std::size_t> || - std::is_convertible_v, ck::long_index_t>>> - HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end()) + typename Layout = DefaultLayout, + typename = std::enable_if_t< + (std::is_convertible_v, std::size_t> || + std::is_convertible_v, ck::long_index_t>) && + std::is_convertible_v>> + HostTensorDescriptor(const Lengths& lens, const Layout& layout = Layout{}) + : HostTensorDescriptor(std::vector(lens.begin(), lens.end()), {}, layout) { - this->CalculateStrides(); + if(dbg) + std::cout << "HostTensorDescriptor ctor (" << __LINE__ << ")" << std::endl; } template && - std::is_convertible_v>> + typename = std::enable_if_t && + std::is_convertible_v>, + typename Layout = DefaultLayout> HostTensorDescriptor(const std::initializer_list& lens, - const std::initializer_list& strides) - : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + const std::initializer_list& strides, + const Layout& layout = Layout{}) + : HostTensorDescriptor(std::vector(lens.begin(), lens.end()), + std::vector(strides.begin(), strides.end()), + layout) { + if(dbg) + std::cout << "HostTensorDescriptor ctor (" << __LINE__ << ")" << std::endl; } + // HostTensorDescriptor({row, col}, {row_stride, col_stride}) + template HostTensorDescriptor(const std::initializer_list& lens, - const std::initializer_list& strides) - : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + const std::initializer_list& strides, + const Layout& layout = Layout{}) + : HostTensorDescriptor(std::vector(lens.begin(), lens.end()), + std::vector(strides.begin(), strides.end()), + layout) { + if(dbg) + std::cout << "HostTensorDescriptor ctor (" << __LINE__ << ")" << std::endl; + } + + // HostTensorDescriptor({row, col}, strides) + template + HostTensorDescriptor(const std::initializer_list& lens, + const Strides& strides, + const Layout& layout = Layout{}) + : HostTensorDescriptor(std::vector(lens.begin(), lens.end()), + std::vector(strides.begin(), strides.end()), + layout) + { + if(dbg) + std::cout << "HostTensorDescriptor ctor (" << __LINE__ << ")" << std::endl; } template , std::size_t> && - std::is_convertible_v, std::size_t>) || - (std::is_convertible_v, ck::long_index_t> && - std::is_convertible_v, ck::long_index_t>)>> - HostTensorDescriptor(const Lengths& lens, const Strides& strides) - : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + typename Layout = DefaultLayout, + typename = std::enable_if_t< + ((std::is_convertible_v, std::size_t> && + std::is_convertible_v, std::size_t>) || + (std::is_convertible_v, ck::long_index_t> && + std::is_convertible_v, ck::long_index_t>)) && + std::is_convertible_v>> + HostTensorDescriptor(const Lengths& lens, + const Strides& strides, + const Layout& layout = Layout{}) + : HostTensorDescriptor(std::vector(lens.begin(), lens.end()), + std::vector(strides.begin(), strides.end()), + layout) { + if(dbg) + std::cout << "HostTensorDescriptor ctor (" << __LINE__ << ")" << std::endl; } std::size_t GetNumOfDimension() const; @@ -173,15 +571,34 @@ struct HostTensorDescriptor } friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); + friend std::ostream& operator<<(std::ostream& os, ChosenLayout tag); private: std::vector mLens; std::vector mStrides; + static constexpr bool dbg = false; + + /** + * @brief Converts a vector of size_t values to a vector of int values. + * + * @param vec The input vector of size_t values to be converted. + * @return std::vector A vector containing the converted int values. + */ + std::vector AsInt(const std::vector& vec) const + { + std::vector strides_int(vec.size()); + std::transform(vec.begin(), vec.end(), strides_int.begin(), [](std::size_t stride) { + return static_cast(stride); + }); + return strides_int; + } }; -template -HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a, - const New2Old& new2old) +template +HostTensorDescriptor +transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a, + const New2Old& new2old, + const NewLayout& new_layout = NewLayout()) { std::vector new_lengths(a.GetNumOfDimension()); std::vector new_strides(a.GetNumOfDimension()); @@ -192,7 +609,7 @@ HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTe new_strides[i] = a.GetStrides()[new2old[i]]; } - return HostTensorDescriptor(new_lengths, new_strides); + return HostTensorDescriptor(new_lengths, new_strides, new_layout); } struct joinable_thread : std::thread @@ -300,6 +717,36 @@ struct Tensor { } + template 0), int> = 0> + Tensor(std::initializer_list lens, Rest&&... rest) + : mDesc(lens, std::forward(rest)...), mData(GetElementSpaceSize()) + { + } + + template 0), int> = 0> + Tensor(std::initializer_list lens, std::initializer_list strides, Rest&&... rest) + : mDesc(lens, strides, std::forward(rest)...), mData(GetElementSpaceSize()) + { + } + + template 0), int> = 0> + Tensor(const Lengths& lens, Rest&&... rest) + : mDesc(lens, std::forward(rest)...), mData(GetElementSpaceSize()) + { + } + + template 0), int> = 0> + Tensor(const Lengths& lens, const Strides& strides, Rest&&... rest) + : mDesc(lens, strides, std::forward(rest)...), mData(GetElementSpaceSize()) + { + } + Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {} template diff --git a/include/ck/library/utility/validation_common.hpp b/include/ck/library/utility/validation_common.hpp deleted file mode 100644 index 38933c6d7c..0000000000 --- a/include/ck/library/utility/validation_common.hpp +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include "ck/ck.hpp" -#include "ck/utility/type.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" - -namespace ck { -namespace utils { - -template -inline void -validate_gemm_stride(int M, int N, int stride, const std::string& stride_name = "Stride") -{ - if(ck::is_same_v) - { - if(stride < M) - { - throw std::runtime_error( - "Error: For ColumnMajor layout, " + stride_name + " (" + std::to_string(stride) + - ") must be greater than or equal to dim (" + std::to_string(M) + ")"); - } - } - else // RowMajor - { - if(stride < N) - { - throw std::runtime_error( - "Error: For RowMajor layout, " + stride_name + " (" + std::to_string(stride) + - ") must be greater than or equal to dim (" + std::to_string(N) + ")"); - } - } -} - -// Convenience functions for common GEMM patterns -template -inline void validate_gemm_strides_abc(int M, int N, int K, int StrideA, int StrideB, int StrideC) -{ - validate_gemm_stride(M, K, StrideA, "StrideA"); - validate_gemm_stride(K, N, StrideB, "StrideB"); - validate_gemm_stride(M, N, StrideC, "StrideC"); -} - -} // namespace utils -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index ff64b6fe2a..d664a822aa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -54,6 +54,9 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr auto xdlops_gemm = XdlopsGemm{}; + using ComputeDataTypeBuf = + conditional_t::value, float, ComputeDataType>; + static constexpr index_t AMmaKStride = KPack; static constexpr index_t BMmaKStride = KPack; @@ -376,7 +379,7 @@ struct BlockwiseGemmXdlops_pipeline_base make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -386,7 +389,7 @@ struct BlockwiseGemmXdlops_pipeline_base A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index f597573dc2..f281184c14 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -240,20 +242,20 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -301,20 +303,20 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -439,6 +441,8 @@ struct BlockwiseGemmXdlops_pipeline_v1( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -551,20 +555,20 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -640,20 +644,20 @@ struct BlockwiseGemmXdlops_pipeline_v1{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -704,7 +708,7 @@ struct BlockwiseGemmXdlops_pipeline_v1, @@ -714,7 +718,7 @@ struct BlockwiseGemmXdlops_pipeline_v1; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp index ea4f5e4a28..1af982e165 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -144,6 +144,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale - // sizeof(ComputeDataType) / sizeof(BDataType) - // ? sizeof(ComputeDataType) / sizeof(ADataType) - // : sizeof(ComputeDataType) / sizeof(BDataType); + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) > + // sizeof(ComputeDataTypeBuf) / + // sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / + // sizeof(ADataType) : sizeof(ComputeDataTypeBuf) + // / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); @@ -351,9 +355,9 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); @@ -516,17 +520,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number::type; xdlops_gemm.template Run<>( @@ -646,17 +650,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number::type; xdlops_gemm.template Run<>( @@ -737,17 +741,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number::type; xdlops_gemm.template Run<>( @@ -791,17 +795,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number::type; xdlops_gemm.template Run<>( @@ -842,7 +846,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale, @@ -852,7 +856,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp index 4246f4a44e..123174e090 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp @@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - auto b_scale_thread_buf = make_static_buffer( + auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 @@ -279,20 +281,20 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; xdlops_gemm.template Run<>( @@ -360,20 +362,20 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; xdlops_gemm.template Run<>( a_thread_vec.template AsType(), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp index 4cc1cf569d..b474ddf528 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -141,6 +141,8 @@ struct BlockwiseGemmXdlops_pipeline_v2= 1 ? 4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( @@ -225,9 +227,9 @@ struct BlockwiseGemmXdlops_pipeline_v2( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -284,20 +286,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -355,20 +357,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -410,20 +412,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -461,20 +463,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -628,6 +630,8 @@ struct BlockwiseGemmXdlops_pipeline_v2( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -786,20 +790,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -885,20 +889,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -961,20 +965,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1037,20 +1041,20 @@ struct BlockwiseGemmXdlops_pipeline_v2{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1129,7 +1133,7 @@ struct BlockwiseGemmXdlops_pipeline_v2, @@ -1139,7 +1143,7 @@ struct BlockwiseGemmXdlops_pipeline_v2; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp index 119f8a3306..70f31246f2 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -143,6 +143,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale= 1 ? 4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( @@ -257,9 +259,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); @@ -351,20 +353,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; xdlops_gemm.template Run<>( @@ -457,20 +459,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; xdlops_gemm.template Run<>( @@ -547,20 +549,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; xdlops_gemm.template Run<>( a_thread_vec.template AsType(), @@ -605,20 +607,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale{}([&](auto n0) { c_thread_buf_per_scale.Clear(); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; xdlops_gemm.template Run<>( a_thread_vec.template AsType(), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp index 80c65515e8..aded984c1e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -141,6 +141,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale= 1 ? 4 * WaveSize / BlockSize : 1; static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( @@ -225,9 +227,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -285,20 +287,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -356,20 +358,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -411,20 +413,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -462,20 +464,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -629,6 +631,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - auto b_scale_thread_buf = make_static_buffer( + auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 @@ -821,20 +825,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -942,20 +946,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1039,20 +1043,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1123,20 +1127,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -1223,7 +1227,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale, @@ -1233,7 +1237,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index 7203348418..f797c611a8 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v3 - // sizeof(ComputeDataType) / sizeof(BDataType) - // ? sizeof(ComputeDataType) / sizeof(ADataType) - // : sizeof(ComputeDataType) / sizeof(BDataType); + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) > + // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) + // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); @@ -295,9 +297,9 @@ struct BlockwiseGemmXdlops_pipeline_v3( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -364,20 +366,20 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -424,20 +426,20 @@ struct BlockwiseGemmXdlops_pipeline_v3{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp index a7d22066ac..3f4f7ea7e8 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -143,6 +143,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale - // sizeof(ComputeDataType) / sizeof(BDataType) - // ? sizeof(ComputeDataType) / sizeof(ADataType) - // : sizeof(ComputeDataType) / sizeof(BDataType); + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) > + // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) + // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); @@ -329,9 +331,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale{}) == 1, "Pipeline v3 only support scaleblocksliceN=1"); // assume kperblock = scaleblockk - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); auto a_scale_thread_buf = make_static_buffer( a_scale_thread_desc.GetElementSpaceSize()); @@ -476,20 +478,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; xdlops_gemm.template Run<>( @@ -578,20 +580,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale()(Number{}) = 0; }); static_for<0, KRepeat, 1>{}([&](auto k0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; xdlops_gemm.template Run<>( a_thread_vec.template AsType(), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp index 3179a90b7f..35be8b9551 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp @@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale - // sizeof(ComputeDataType) / sizeof(BDataType) - // ? sizeof(ComputeDataType) / sizeof(ADataType) - // : sizeof(ComputeDataType) / sizeof(BDataType); + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) > + // sizeof(ComputeDataTypeBuf) / sizeof(BDataType) + // ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType) + // : sizeof(ComputeDataTypeBuf) / sizeof(BDataType); constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); @@ -307,13 +309,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // B scale buffer - auto b_scale_thread_buf = make_static_buffer( + auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); // Global prefetch 1 @@ -429,20 +431,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -491,20 +493,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp index 9835d9325b..c762b3be15 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v4( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); StaticallyIndexedArray{}> a_thread_bufs; @@ -369,22 +371,22 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf] [Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf] [Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -439,20 +441,20 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -492,20 +494,20 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -524,20 +526,20 @@ struct BlockwiseGemmXdlops_pipeline_v4{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp index f35c7a97cc..3819f572c0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp @@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // B scale buffer - auto b_scale_thread_buf = make_static_buffer( + auto b_scale_thread_buf = make_static_buffer( b_scale_thread_desc.GetElementSpaceSize()); StaticallyIndexedArray{}> a_thread_bufs; @@ -478,22 +480,22 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf] [Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf] [Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -549,20 +551,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -603,20 +605,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -635,20 +637,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_bufs[mfma_reg_buf][Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_bufs[mfma_reg_buf][Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp index 99934fa74e..d5bc6369dd 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -144,6 +144,8 @@ struct BlockwiseGemmXdlops_pipeline_v5( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); // Global prefetch 1 @@ -405,8 +407,8 @@ struct BlockwiseGemmXdlops_pipeline_v5 a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KRepeat, 1>{}([&](auto k0) { if constexpr(k0 == (KRepeat - 1)) @@ -427,18 +429,18 @@ struct BlockwiseGemmXdlops_pipeline_v5{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; }); static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; constexpr index_t c_offset = @@ -481,8 +483,8 @@ struct BlockwiseGemmXdlops_pipeline_v5 a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KRepeat, 1>{}([&](auto k0) { if constexpr(k0 == (KRepeat - 1)) @@ -497,18 +499,18 @@ struct BlockwiseGemmXdlops_pipeline_v5{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; }); static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -540,25 +542,25 @@ struct BlockwiseGemmXdlops_pipeline_v5 a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KRepeat - 1, 1>{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; }); static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -591,16 +593,16 @@ struct BlockwiseGemmXdlops_pipeline_v5{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = a_thread_buf + a_thread_vec.template AsType()(ik) = a_thread_buf [Number{}]; }); static_for<0, KPack, 1>{}([&](auto ik) { - b_thread_vec.template AsType()(ik) = b_thread_buf + b_thread_vec.template AsType()(ik) = b_thread_buf [Number{}]; }); using mfma_input_type = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -637,7 +639,7 @@ struct BlockwiseGemmXdlops_pipeline_v5{}, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -647,7 +649,7 @@ struct BlockwiseGemmXdlops_pipeline_v5; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp index cbb9fadc6d..5de33c90fe 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -55,6 +55,155 @@ struct DeviceGemmMultipleABD : public BaseOperator virtual std::unique_ptr MakeInvokerPointer() = 0; }; +// GEMM: +// input : A0[M, K], B0[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleABDSplitK : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +/// @brief Wrapper for backward compatibility that allows to use instances of +/// DeviceGemmMultipleABDSplitK in contexts where DeviceGemmMultipleABD is expected. +/// +/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances(). +/// The only difference between API of DeviceGemmMultipleABD and DeviceGemmMultipleABDSplitK +/// is that DeviceGemmMultipleABDSplitK::MakeArgumentPointer requires an additional parameter +/// KBatch which is explicitly passed as 1 by this wrapper. +template +struct DeviceGemmMultipleABDSplitKWrapper : public DeviceGemmMultipleABD +{ + + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef __HIPCC_RTC__ + + explicit DeviceGemmMultipleABDSplitKWrapper(std::unique_ptr p_op) + : p_op_(std::move(p_op)) + { + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return p_op_->IsSupportedArgument(p_arg); + } + std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return p_op_->MakeArgumentPointer(p_as, + p_bs, + p_ds, + p_e, + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + 1, // KBatch + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return p_op_->MakeInvokerPointer(); + } + + std::string GetTypeString() const override { return p_op_->GetTypeString(); } + + private: + std::unique_ptr p_op_; + +#endif // __HIPCC_RTC__ +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index c00078186f..e305dbfd9a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -64,9 +64,27 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + // shift A matrices pointer for splitk + typename GridwiseGemm::AsGridPointer p_as_grid_shift; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType_ = + remove_cvref_t>; + p_as_grid_shift(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i] + a_batch_offset; + }); + + // shift B matrices pointer for splitk + typename GridwiseGemm::BsGridPointer p_bs_grid_shift; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType_ = + remove_cvref_t>; + p_bs_grid_shift(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i] + b_batch_offset; + }); + GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset, + p_as_grid_shift, + p_bs_grid_shift, karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, p_shared, @@ -278,8 +296,8 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm, // DsLayout CLayout, - ADataType, - BDataType, + Tuple, + Tuple, AccDataType, CShuffleDataType, Tuple<>, // DsDataType @@ -346,15 +364,15 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm{p_a_grid_}, + std::array{p_b_grid_}, std::array{}, // p_ds_grid_ p_c_grid_, M_, N_, K_, - StrideA_, - StrideB_, + std::array{StrideA_}, + std::array{StrideB_}, std::array{}, // StrideDs_ StrideC_, k_batch_, @@ -423,26 +441,33 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm rotating_mem( - arg_, - stream_config.rotating_count, - arg_.Batch * size_a_buffer, - arg_.Batch * size_b_buffer); + std::array size_as_buffers; + size_as_buffers[0] = a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize * arg_.Batch; + + std::array size_bs_buffers; + size_bs_buffers[0] = b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize * arg_.Batch; + + ck::utility::RotatingMemWrapperMultiABD, + Tuple, + Tuple<>> + rotating_mem(arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + std::array{}); rotating_mem.Print(); auto run_flush_cache = [&]() { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..48914479bc --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp @@ -0,0 +1,422 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// @brief \"Universal\" GEMM operation with SplitK support and multiple D tensors. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// E{M,N} = CDE_op(A_op(As{M,K}...) * B_op(Bs{K,N}...), Ds{M,N}...) +/// Where As, Bs, Ds are input tensors and E is the output tensor. The A/B_op are +/// elementwise +// operations that could be applied on each tensor respectively. The CDE_op is an +// elementwise operation applied to the C and all D tensors. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam AsLayout A tensors data layouts. +/// @tparam BsLayout B tensors data layouts. +/// @tparam DsLayout D tensors data layouts. +/// @tparam ELayout E tensor data layout. +/// @tparam AsDataType A tensors data types. +/// @tparam BsDataType B tensors data types. +/// @tparam DsDataType D tensors data types. +/// @tparam EDataType E tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CDEElementwiseOperation Elementwise operation applied to the C output tensor (after +/// GEMM) and D input tensors. +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CDEShuffleBlockTransferScalarPerVectors The size of vectorized memory access. +/// Used when loading data from D tensors and storing data +/// to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct DeviceGemmMultipleABD_Wmma_CShuffleV3 + : public DeviceGemmMultipleABDSplitK +{ + // Note: Pass multiple layout but then using only the first one + // This is to replicate xdl functionality but it should be extended + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + AsDataType, + BsDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + static_cast(p_e), + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + static_cast(p_e), + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmMultipleABD_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", "; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ALayout_ = remove_cvref_t>; + + str << std::string(ALayout_::name)[0]; + }); + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BLayout_ = remove_cvref_t>; + + str << std::string(BLayout_::name)[0]; + }); + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + str << std::string(DLayout::name)[0]; + }); + str << std::string(ELayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"<()) + { + __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()]; - GridwiseGemmWelford::template Run( - p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_welford_mean_grid, - p_welford_var_grid, - p_welford_count_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - mean_var_grid_desc_mblock_mperblock_nblock, - count_grid_desc_mblock_mperblock_nblock, - block_2_etile_map, - NRaw); + GridwiseGemmWelford::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_welford_mean_grid, + p_welford_var_grid, + p_welford_count_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + mean_var_grid_desc_mblock_mperblock_nblock, + count_grid_desc_mblock_mperblock_nblock, + block_2_etile_map, + NRaw); + } #else ignore = p_a_grid; ignore = p_b_grid; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp index 6cd5020642..b7cc7bd7d0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -193,8 +193,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3 BLayout, DsLayout, ELayout, - ADataType, - BDataType, + Tuple, + Tuple, AccDataType, CShuffleDataType, DsDataType, @@ -244,8 +244,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3 using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, DsDataType, EDataType, MPerBlock, @@ -291,15 +291,15 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3 BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) { - return Argument{static_cast(p_a), - static_cast(p_b), + return Argument{std::array{p_a}, + std::array{p_b}, p_ds, static_cast(p_e), M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, StrideDs, StrideE, KBatch, @@ -328,15 +328,15 @@ struct DeviceGemmMultipleD_Wmma_CShuffleV3 BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override { - return std::make_unique(static_cast(p_a), - static_cast(p_b), + return std::make_unique(std::array{p_a}, + std::array{p_b}, p_ds, static_cast(p_e), M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, StrideDs, StrideE, KBatch, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp index f1eb5e5d64..2ceeb39bac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp @@ -182,8 +182,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2, // DsLayout CLayout, - ADataType, - BDataType, + Tuple, + Tuple, AccDataType, CShuffleDataType, Tuple<>, // DsDataType @@ -233,8 +233,8 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2, + Tuple, Tuple<>, CDataType, MPerBlock, @@ -283,15 +283,15 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2{p_a}, + std::array{p_b}, std::array{}, // p_ds_grid_ p_c, M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, std::array{}, // StrideDs_ StrideC, KBatch, @@ -317,15 +317,15 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2(static_cast(p_a), - static_cast(p_b), + return std::make_unique(std::array{p_a}, + std::array{p_b}, std::array{}, // p_ds_grid_ static_cast(p_c), M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, std::array{}, // StrideDs_ StrideC, KBatch, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp index a9d5c666a9..5e9a861f41 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -91,8 +91,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale, // DsLayout CLayout, - ADataType, - BDataType, + Tuple, + Tuple, + BScaleDataType, AccDataType, CShuffleDataType, Tuple<>, // DsDataType @@ -144,8 +145,8 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale, + Tuple, Tuple<>, CDataType, MPerBlock, @@ -195,15 +196,15 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale{p_a}, + std::array{p_b}, std::array{}, // p_ds_grid_ p_c, M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, std::array{}, // StrideDs_ StrideC, StrideScaleB, @@ -233,15 +234,15 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale(static_cast(p_a), - static_cast(p_b), + return std::make_unique(std::array{p_a}, + std::array{p_b}, std::array{}, // p_ds_grid_ static_cast(p_c), M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, std::array{}, // StrideDs_ StrideC, StrideScaleB, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 72191632d8..4269d67d12 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -23,8 +23,8 @@ namespace tensor_operation { namespace device { template size_as_buffers; + static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + size_as_buffers[i] = a_grid_desc_ak0_m_ak1[i].GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + }); + + std::array size_bs_buffers; + static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + size_bs_buffers[i] = b_grid_desc_bk0_n_bk1[i].GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + }); const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); @@ -108,12 +117,13 @@ struct DeviceGemm_Wmma_CShuffleV3_Common ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); }); - ck::utility::RotatingMemWrapperMultiD rotating_mem( - arg_, - stream_config.rotating_count, - size_a_buffer, - size_b_buffer, - size_ds_buffers); + ck::utility:: + RotatingMemWrapperMultiABD + rotating_mem(arg_, + stream_config.rotating_count, + size_as_buffers, + size_bs_buffers, + size_ds_buffers); rotating_mem.Print(); auto run_flush_cache = [&]() { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp index 3a06ea8451..df51a2aa27 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -98,8 +98,8 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1, CLayout, - ADataType, - BDataType, + Tuple, + Tuple, GemmAccDataType, ReduceDataType, Tuple<>, @@ -147,15 +147,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1 p_a_grid_, + std::array p_b_grid_, const ::std::array p_ds_, CDataType* p_c_grid_, index_t M_, index_t N_, index_t K_, - index_t StrideA_, - index_t StrideB_, + std::array StrideA_, + std::array StrideB_, const ::std::array stride_ds_, index_t StrideC_, index_t KBatch_, @@ -430,15 +430,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1{p_a}, + std::array{p_b}, p_ds, p_c, M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, stride_ds, StrideC, KBatch, @@ -472,15 +472,15 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1(static_cast(p_a), - static_cast(p_b), + return ::std::make_unique(std::array{p_a}, + std::array{p_b}, p_ds, static_cast(p_c), M, N, K, - StrideA, - StrideB, + std::array{StrideA}, + std::array{StrideB}, DsStrides, StrideC, KSplit, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp index bc192b7651..4abd14b080 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -321,12 +321,6 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm, remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -352,8 +345,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm, remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -384,8 +376,8 @@ struct DeviceGemmXdlSkipBLds : public DeviceGemm{}), - make_tuple(Sequence<0>{})); + return transform_tensor_descriptor(descriptor, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); } else { @@ -616,7 +615,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle using RDataType = remove_cvref_t>; // R pointer - p_rs_grid_(i) = static_cast(p_rs[i]); + p_rs_grid_(i) = static_cast(p_rs[i]); + compute_ptr_offset_of_batch_.BatchStrideRs_(i) = r_g_n_wos_strides[0]; }); } diff --git a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp index e836e73a1d..79deb81512 100644 --- a/include/ck/tensor_operation/gpu/device/tensor_layout.hpp +++ b/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -8,21 +8,31 @@ namespace tensor_layout { struct BaseTensorLayout { + static constexpr const char* name = "BaseTensorLayout"; +}; + +struct BypassLayoutVerification : public BaseTensorLayout +{ + static constexpr const char* name = "BypassLayoutVerification"; }; namespace gemm { -struct RowMajor : public BaseTensorLayout +struct BaseGemmLayout : public BaseTensorLayout +{ + static constexpr const char* name = "BaseConvolutionLayout"; +}; +struct RowMajor : public BaseGemmLayout { static constexpr const char* name = "RowMajor"; }; -struct ColumnMajor : public BaseTensorLayout +struct ColumnMajor : public BaseGemmLayout { static constexpr const char* name = "ColumnMajor"; }; -struct MFMA : public BaseTensorLayout +struct MFMA : public BaseGemmLayout { static constexpr const char* name = "MFMA"; }; @@ -31,405 +41,410 @@ struct MFMA : public BaseTensorLayout namespace convolution { +struct BaseConvolutionLayout : public BaseTensorLayout +{ + static constexpr const char* name = "BaseConvolutionLayout"; +}; + // input tensor // packed NCW/NCHW/NCDHW -struct NCW : public BaseTensorLayout +struct NCW : public BaseConvolutionLayout { static constexpr const char* name = "NCW"; }; -struct NCHW : public BaseTensorLayout +struct NCHW : public BaseConvolutionLayout { static constexpr const char* name = "NCHW"; }; -struct NCDHW : public BaseTensorLayout +struct NCDHW : public BaseConvolutionLayout { static constexpr const char* name = "NCDHW"; }; // packed GNCW/GNCHW/GNCDHW -struct GNCW : public BaseTensorLayout +struct GNCW : public BaseConvolutionLayout { static constexpr const char* name = "GNCW"; }; -struct GNCHW : public BaseTensorLayout +struct GNCHW : public BaseConvolutionLayout { static constexpr const char* name = "GNCHW"; }; -struct GNCDHW : public BaseTensorLayout +struct GNCDHW : public BaseConvolutionLayout { static constexpr const char* name = "GNCDHW"; }; // input tensor // packed NWC/NHWC/NDHWC -struct NWC : public BaseTensorLayout +struct NWC : public BaseConvolutionLayout { static constexpr const char* name = "NWC"; }; -struct NHWC : public BaseTensorLayout +struct NHWC : public BaseConvolutionLayout { static constexpr const char* name = "NHWC"; }; -struct NDHWC : public BaseTensorLayout +struct NDHWC : public BaseConvolutionLayout { static constexpr const char* name = "NDHWC"; }; // input tensor // packed GNWC/GNHWC/GNDHWC -struct GNWC : public BaseTensorLayout +struct GNWC : public BaseConvolutionLayout { static constexpr const char* name = "GNWC"; }; -struct GNHWC : public BaseTensorLayout +struct GNHWC : public BaseConvolutionLayout { static constexpr const char* name = "GNHWC"; }; -struct GNDHWC : public BaseTensorLayout +struct GNDHWC : public BaseConvolutionLayout { static constexpr const char* name = "GNDHWC"; }; // for input bias -struct GC : public BaseTensorLayout +struct GC : public BaseConvolutionLayout { static constexpr const char* name = "GC"; }; // input tensor // packed NWGC/NHWGC/NDHWGC -struct NWGC : public BaseTensorLayout +struct NWGC : public BaseConvolutionLayout { static constexpr const char* name = "NWGC"; }; -struct NHWGC : public BaseTensorLayout +struct NHWGC : public BaseConvolutionLayout { static constexpr const char* name = "NHWGC"; }; -struct NDHWGC : public BaseTensorLayout +struct NDHWGC : public BaseConvolutionLayout { static constexpr const char* name = "NDHWGC"; }; // input tensor // packed NGCW/NGCHW/NGCDHW -struct NGCW : public BaseTensorLayout +struct NGCW : public BaseConvolutionLayout { static constexpr const char* name = "NGCW"; }; -struct NGCHW : public BaseTensorLayout +struct NGCHW : public BaseConvolutionLayout { static constexpr const char* name = "NGCHW"; }; -struct NGCDHW : public BaseTensorLayout +struct NGCDHW : public BaseConvolutionLayout { static constexpr const char* name = "NGCDHW"; }; // input tensor // strided layout -struct G_NW_C : public BaseTensorLayout +struct G_NW_C : public BaseConvolutionLayout { static constexpr const char* name = "G_NW_C"; }; -struct G_NHW_C : public BaseTensorLayout +struct G_NHW_C : public BaseConvolutionLayout { static constexpr const char* name = "G_NHW_C"; }; -struct G_NDHW_C : public BaseTensorLayout +struct G_NDHW_C : public BaseConvolutionLayout { static constexpr const char* name = "G_NDHW_C"; }; // for input bias -struct G_C : public BaseTensorLayout +struct G_C : public BaseConvolutionLayout { static constexpr const char* name = "G_C"; }; // weight tensor // packed KCX/KCYX/KCZYX -struct KCX : public BaseTensorLayout +struct KCX : public BaseConvolutionLayout { static constexpr const char* name = "KCX"; }; -struct KCYX : public BaseTensorLayout +struct KCYX : public BaseConvolutionLayout { static constexpr const char* name = "KCYX"; }; -struct KCZYX : public BaseTensorLayout +struct KCZYX : public BaseConvolutionLayout { static constexpr const char* name = "KCZYX"; }; // weight tensor // packed KCX/KCYX/KCZYX -struct GKCX : public BaseTensorLayout +struct GKCX : public BaseConvolutionLayout { static constexpr const char* name = "GKCX"; }; -struct GKCYX : public BaseTensorLayout +struct GKCYX : public BaseConvolutionLayout { static constexpr const char* name = "GKCYX"; }; -struct GKCZYX : public BaseTensorLayout +struct GKCZYX : public BaseConvolutionLayout { static constexpr const char* name = "GKCZYX"; }; // weight tensor // packed KXC/KYXC/KZYXC -struct KXC : public BaseTensorLayout +struct KXC : public BaseConvolutionLayout { static constexpr const char* name = "KXC"; }; -struct KYXC : public BaseTensorLayout +struct KYXC : public BaseConvolutionLayout { static constexpr const char* name = "KYXC"; }; -struct KZYXC : public BaseTensorLayout +struct KZYXC : public BaseConvolutionLayout { static constexpr const char* name = "KZYXC"; }; // weight tensor // packed GKXC/GKYXC/GKZYXC -struct GKXC : public BaseTensorLayout +struct GKXC : public BaseConvolutionLayout { static constexpr const char* name = "GKXC"; }; -struct GKYXC : public BaseTensorLayout +struct GKYXC : public BaseConvolutionLayout { static constexpr const char* name = "GKYXC"; }; -struct GKZYXC : public BaseTensorLayout +struct GKZYXC : public BaseConvolutionLayout { static constexpr const char* name = "GKZYXC"; }; // weight tensor // packed KXGC/KYXGC/KZYXGC -struct KXGC : public BaseTensorLayout +struct KXGC : public BaseConvolutionLayout { static constexpr const char* name = "KXGC"; }; -struct KYXGC : public BaseTensorLayout +struct KYXGC : public BaseConvolutionLayout { static constexpr const char* name = "KYXGC"; }; -struct KZYXGC : public BaseTensorLayout +struct KZYXGC : public BaseConvolutionLayout { static constexpr const char* name = "KZYXGC"; }; // weight tensor // strided -struct G_K_X_C : public BaseTensorLayout +struct G_K_X_C : public BaseConvolutionLayout { static constexpr const char* name = "G_K_X_C"; }; -struct G_K_YX_C : public BaseTensorLayout +struct G_K_YX_C : public BaseConvolutionLayout { static constexpr const char* name = "G_K_YX_C"; }; -struct G_K_ZYX_C : public BaseTensorLayout +struct G_K_ZYX_C : public BaseConvolutionLayout { static constexpr const char* name = "G_K_ZYX_C"; }; // output tensor // packed NKW/NKHW/NKDHW -struct NKW : public BaseTensorLayout +struct NKW : public BaseConvolutionLayout { static constexpr const char* name = "NKW"; }; -struct NKHW : public BaseTensorLayout +struct NKHW : public BaseConvolutionLayout { static constexpr const char* name = "NKHW"; }; -struct NKDHW : public BaseTensorLayout +struct NKDHW : public BaseConvolutionLayout { static constexpr const char* name = "NKDHW"; }; // output tensor // packed GNKW/GNKHW/GNKDHW -struct GNKW : public BaseTensorLayout +struct GNKW : public BaseConvolutionLayout { static constexpr const char* name = "GNKW"; }; -struct GNKHW : public BaseTensorLayout +struct GNKHW : public BaseConvolutionLayout { static constexpr const char* name = "GNKHW"; }; -struct GNKDHW : public BaseTensorLayout +struct GNKDHW : public BaseConvolutionLayout { static constexpr const char* name = "GNKDHW"; }; // output tensor // packed NWK/NHWK/NDHWK -struct NWK : public BaseTensorLayout +struct NWK : public BaseConvolutionLayout { static constexpr const char* name = "NWK"; }; -struct NHWK : public BaseTensorLayout +struct NHWK : public BaseConvolutionLayout { static constexpr const char* name = "NHWK"; }; -struct NDHWK : public BaseTensorLayout +struct NDHWK : public BaseConvolutionLayout { static constexpr const char* name = "NDHWK"; }; // output tensor // packed GNWK/GNHWK/GNDHWK -struct GNWK : public BaseTensorLayout +struct GNWK : public BaseConvolutionLayout { static constexpr const char* name = "GNWK"; }; -struct GNHWK : public BaseTensorLayout +struct GNHWK : public BaseConvolutionLayout { static constexpr const char* name = "GNHWK"; }; -struct GNDHWK : public BaseTensorLayout +struct GNDHWK : public BaseConvolutionLayout { static constexpr const char* name = "GNDHWK"; }; // output tensor // packed NWGK/NHWGK/NDHWGK -struct NWGK : public BaseTensorLayout +struct NWGK : public BaseConvolutionLayout { static constexpr const char* name = "NWGK"; }; -struct NHWGK : public BaseTensorLayout +struct NHWGK : public BaseConvolutionLayout { static constexpr const char* name = "NHWGK"; }; -struct NDHWGK : public BaseTensorLayout +struct NDHWGK : public BaseConvolutionLayout { static constexpr const char* name = "NDHWGK"; }; -struct NGKW : public BaseTensorLayout +struct NGKW : public BaseConvolutionLayout { static constexpr const char* name = "NGKW"; }; -struct NGKHW : public BaseTensorLayout +struct NGKHW : public BaseConvolutionLayout { static constexpr const char* name = "NGKHW"; }; -struct NGKDHW : public BaseTensorLayout +struct NGKDHW : public BaseConvolutionLayout { static constexpr const char* name = "NGKDHW"; }; // output tensor // strided layout -struct G_NW_K : public BaseTensorLayout +struct G_NW_K : public BaseConvolutionLayout { static constexpr const char* name = "G_NW_K"; }; -struct G_NHW_K : public BaseTensorLayout +struct G_NHW_K : public BaseConvolutionLayout { static constexpr const char* name = "G_NHW_K"; }; -struct G_NDHW_K : public BaseTensorLayout +struct G_NDHW_K : public BaseConvolutionLayout { static constexpr const char* name = "G_NDHW_K"; }; // for output bias -struct G_K : public BaseTensorLayout +struct G_K : public BaseConvolutionLayout { static constexpr const char* name = "G_K"; }; // K-reduced output tensor (packed) -struct GNW : public BaseTensorLayout +struct GNW : public BaseConvolutionLayout { static constexpr const char* name = "GNW"; }; -struct GNHW : public BaseTensorLayout +struct GNHW : public BaseConvolutionLayout { static constexpr const char* name = "GNHW"; }; -struct GNDHW : public BaseTensorLayout +struct GNDHW : public BaseConvolutionLayout { static constexpr const char* name = "GNDHW"; }; // K-reduced output tensor (packed) -struct NWG : public BaseTensorLayout +struct NWG : public BaseConvolutionLayout { static constexpr const char* name = "NWG"; }; -struct NHWG : public BaseTensorLayout +struct NHWG : public BaseConvolutionLayout { static constexpr const char* name = "NHWG"; }; -struct NDHWG : public BaseTensorLayout +struct NDHWG : public BaseConvolutionLayout { static constexpr const char* name = "NDHWG"; }; // K-reduced output tensor (strided) -struct G_NW : public BaseTensorLayout +struct G_NW : public BaseConvolutionLayout { static constexpr const char* name = "G_NW"; }; -struct G_NHW : public BaseTensorLayout +struct G_NHW : public BaseConvolutionLayout { static constexpr const char* name = "G_NHW"; }; -struct G_NDHW : public BaseTensorLayout +struct G_NDHW : public BaseConvolutionLayout { static constexpr const char* name = "G_NDHW"; }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index cbad6a5673..ad28a12e57 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -107,8 +107,11 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle using BComputeDataType = conditional_t, ck::bhalf_t, BComputeDataType_>; #else - using AComputeDataType = AComputeDataType_; - using BComputeDataType = BComputeDataType_; + // Element data type is used in LDS and registers. ComputeDataType_ is inside mfma, eg tf32. + using AElementDataType = + conditional_t, float, AComputeDataType_>; + using BElementDataType = + conditional_t, float, BComputeDataType_>; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -199,8 +202,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + - b_block_space_size_aligned * sizeof(BComputeDataType), + return math::max(a_block_space_size_aligned * sizeof(AElementDataType) + + b_block_space_size_aligned * sizeof(BElementDataType), c_block_size * sizeof(CShuffleDataType)); } @@ -621,7 +624,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, AsDataType, - Tuple, + Tuple, decltype(as_grid_desc_ak0_m_ak1), decltype(tie(a_block_desc_ak0_m_ak1)), AElementwiseOperation, @@ -649,7 +652,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, BsDataType, - Tuple, + Tuple, decltype(bs_grid_desc_bk0_n_bk1), decltype(tie(b_block_desc_bk0_n_bk1)), BElementwiseOperation, @@ -679,27 +682,28 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // sanity check constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); constexpr bool is_single_rate_mfma = - (((is_same::value || - is_same::value) && + (((is_same::value || + is_same::value) && lcm_AK1_BK1 <= 4) || - (is_same::value && lcm_AK1_BK1 <= 8) || - ((is_same::value || is_same::value) && + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || + is_same::value) && lcm_AK1_BK1 < 32)) ? true : false; static constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - AComputeDataType, - BComputeDataType, + AElementDataType, + BElementDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -709,8 +713,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle NXdlPerWave, KPack, LoopSched, - AComputeDataType, - BComputeDataType>(); + AComputeDataType_, + BComputeDataType_>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -719,10 +723,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index bd2a8b04bc..d226510cf0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -11,6 +11,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -39,8 +40,8 @@ namespace ck { /// @tparam BLayout B tensor data layout. /// @tparam DsLayout D tensors data layouts. /// @tparam ELayout E tensor data layout. -/// @tparam ADataType A tensor data type. -/// @tparam BDataType B tensor data type. +/// @tparam AsDataType A tensors data types. +/// @tparam BsDataType B tensors data types. /// @tparam AccDataType The accumulation data type related to the hardware /// matrix-multiplication instruction. /// @tparam CShuffleDataType The data type used to store matrix-multiplication results into @@ -129,8 +130,8 @@ template StrideAs_, + std::array StrideBs_, std::array StrideDs_, index_t StrideE_, index_t KBatch_) : M{M_}, N{N_}, K{K_}, - StrideA{StrideA_}, - StrideB{StrideB_}, + StrideAs{StrideAs_}, + StrideBs{StrideBs_}, StrideDs{StrideDs_}, StrideE{StrideE_}, KBatch{KBatch_}, @@ -355,7 +362,15 @@ struct GridwiseGemm_wmma_cshuffle_v3 __host__ void Print() const { std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " - << "SA:" << StrideA << ", " << "SB:" << StrideB << ", "; + << "SAs: {"; + static_for<0, NumATensor, 1>{}([&](auto i) { + std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : ""); + }); + std::cout << "}, " << "SBs: {"; + static_for<0, NumBTensor, 1>{}([&](auto i) { + std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : ""); + }); + std::cout << "}, "; if constexpr(NumDTensor > 0) { std::cout << "SDs: { "; @@ -373,8 +388,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t M; index_t N; index_t K; - index_t StrideA; - index_t StrideB; + std::array StrideAs; + std::array StrideBs; std::array StrideDs; index_t StrideE; index_t KBatch; @@ -391,15 +406,15 @@ struct GridwiseGemm_wmma_cshuffle_v3 // Argument struct Argument : public tensor_operation::device::BaseArgument, public Problem { - __host__ Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, + __host__ Argument(std::array p_as_grid_, + std::array p_bs_grid_, std::array p_ds_grid_, EDataType* p_e_grid_, index_t M_, index_t N_, index_t K_, - index_t StrideA_, - index_t StrideB_, + std::array StrideAs_, + std::array StrideBs_, std::array StrideDs_, index_t StrideE_, index_t k_batch_, @@ -407,9 +422,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_ = false) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, k_batch_}, - p_a_grid{p_a_grid_}, - p_b_grid{p_b_grid_}, + : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideE_, k_batch_}, + p_as_grid{}, + p_bs_grid{}, p_ds_grid{}, p_e_grid{p_e_grid_}, a_element_op{a_element_op_}, @@ -417,9 +432,27 @@ struct GridwiseGemm_wmma_cshuffle_v3 cde_element_op{cde_element_op_}, is_reduce(is_reduce_) { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + + // A pointer + p_as_grid(i) = static_cast(p_as_grid_[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + + // B pointer + p_bs_grid(i) = static_cast(p_bs_grid_[i]); + }); + + // populate pointer, desc for Ds static_for<0, NumDTensor, 1>{}([&](auto i) { using DDataType = remove_cvref_t>; + // D pointer p_ds_grid(i) = static_cast(p_ds_grid_[i]); }); } @@ -434,8 +467,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 return (Problem::KBatch > 1) && (!is_reduce); } - const ADataType* p_a_grid; - const BDataType* p_b_grid; + AsGridPointer p_as_grid; + BsGridPointer p_bs_grid; DsGridPointer p_ds_grid; EDataType* p_e_grid; @@ -452,29 +485,39 @@ struct GridwiseGemm_wmma_cshuffle_v3 __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { + // Note: in xdl implementation multiple AB supports one layout + // but multiple strides, so we create an array of offsets with + // the same values. + // It should be fixed later on. Once we will have a thread transfer + // more flexible. if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead / APackedSize; + static_for<0, NumATensor, 1>{}( + [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; }); } else if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead * karg.StrideA; + static_for<0, NumATensor, 1>{}( + [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; }); } if constexpr(is_same_v) { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; }); } else if constexpr(is_same_v) { if constexpr(!PermuteB) { - b_k_split_offset = k_id * karg.KRead / BPackedSize; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); } else { const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = k_id * k0_offset / BPackedSize; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); } } @@ -497,8 +540,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 } } - index_t a_k_split_offset; - index_t b_k_split_offset; + std::array a_k_split_offset; + std::array b_k_split_offset; index_t c_reduce_offset; }; @@ -514,8 +557,8 @@ struct GridwiseGemm_wmma_cshuffle_v3 template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, DsGridPointer& p_ds_grid, EDataType* p_e_grid, void* p_shared, @@ -524,10 +567,10 @@ struct GridwiseGemm_wmma_cshuffle_v3 BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) { - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( - problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0); const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N( @@ -562,20 +605,20 @@ struct GridwiseGemm_wmma_cshuffle_v3 const index_t num_k_block_per_scale = GetKBlockPerScale(); - Base::template Run(p_a_grid, - p_b_grid, + TailNum>(p_as_grid, + p_bs_grid, p_ds_grid, p_e_grid, p_shared, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -595,10 +638,26 @@ struct GridwiseGemm_wmma_cshuffle_v3 __device__ static void Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg) { + // shift A matrices pointer for splitk + AsGridPointer p_as_grid_splitk; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_splitk(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i]; + }); + + // shift B matrices pointer for splitk + BsGridPointer p_bs_grid_splitk; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_splitk(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i]; + }); + Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset, + p_as_grid_splitk, + p_bs_grid_splitk, + karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset, p_shared, karg, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp index 29c5ae31cd..46de6b156a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -22,8 +22,9 @@ template { - using BScaleType = ck::half_t; - using Base = GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, DsLayout, ELayout, - ADataType, - BDataType, + AsDataType, + BsDataType, AccDataType, CShuffleDataType, DsDataType, @@ -202,8 +201,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using Base::CalculateMPadded; using Base::CalculateNBlock; using Base::CalculateNPadded; - using Base::MakeAGridDescriptor_AK0_M_AK1; - using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeAsGridDescriptor_AK0_M_AK1; + using Base::MakeBsGridDescriptor_BK0_N_BK1; using Base::MakeDEGridDescriptor_M_N; using Base::MakeDsGridDescriptor_M_N; using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; @@ -217,7 +216,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + using Base::NumATensor; + using Base::NumBTensor; using Base::NumDTensor; + using typename Base::AsGridPointer; + using typename Base::BsGridPointer; using typename Base::DsGridPointer; struct Problem @@ -225,8 +228,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __host__ Problem(index_t M_, index_t N_, index_t K_, - index_t StrideA_, - index_t StrideB_, + std::array StrideAs_, + std::array StrideBs_, std::array StrideDs_, index_t StrideE_, index_t StrideScaleB_, @@ -234,8 +237,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale : M{M_}, N{N_}, K{K_}, - StrideA{StrideA_}, - StrideB{StrideB_}, + StrideAs{StrideAs_}, + StrideBs{StrideBs_}, StrideDs{StrideDs_}, StrideE{StrideE_}, StrideScaleB{StrideScaleB_}, @@ -254,7 +257,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __host__ void Print() const { std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " - << "SA:" << StrideA << ", " << "SB:" << StrideB << ", "; + << "SAs: {"; + static_for<0, NumATensor, 1>{}([&](auto i) { + std::cout << StrideAs[i] << (i.value < NumATensor - 1 ? ", " : ""); + }); + std::cout << "}, " << "SBs: {"; + static_for<0, NumBTensor, 1>{}([&](auto i) { + std::cout << StrideBs[i] << (i.value < NumBTensor - 1 ? ", " : ""); + }); + std::cout << "}, "; if constexpr(NumDTensor > 0) { std::cout << "SDs: { "; @@ -273,8 +284,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale index_t M; index_t N; index_t K; - index_t StrideA; - index_t StrideB; + std::array StrideAs; + std::array StrideBs; std::array StrideDs; index_t StrideE; index_t StrideScaleB; @@ -292,15 +303,15 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale // Argument struct Argument : public tensor_operation::device::BaseArgument, public Problem { - __host__ Argument(const ADataType* p_a_grid_, - const BDataType* p_b_grid_, + __host__ Argument(std::array p_as_grid_, + std::array p_bs_grid_, std::array p_ds_grid_, EDataType* p_e_grid_, index_t M_, index_t N_, index_t K_, - index_t StrideA_, - index_t StrideB_, + std::array StrideAs_, + std::array StrideBs_, std::array StrideDs_, index_t StrideE_, index_t StrideScaleB_, @@ -310,9 +321,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_ = false) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideE_, StrideScaleB_, k_batch_}, - p_a_grid{p_a_grid_}, - p_b_grid{p_b_grid_}, + : Problem{M_, + N_, + K_, + StrideAs_, + StrideBs_, + StrideDs_, + StrideE_, + StrideScaleB_, + k_batch_}, + p_as_grid{}, + p_bs_grid{}, p_ds_grid{}, p_e_grid{p_e_grid_}, p_b_scale_grid{p_b_scale_grid_}, @@ -321,6 +340,22 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale cde_element_op{cde_element_op_}, is_reduce(is_reduce_) { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + + // A pointer + p_as_grid(i) = static_cast(p_as_grid_[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + + // B pointer + p_bs_grid(i) = static_cast(p_bs_grid_[i]); + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { using DDataType = remove_cvref_t>; @@ -338,8 +373,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale return (Problem::KBatch > 1) && (!is_reduce); } - const ADataType* p_a_grid; - const BDataType* p_b_grid; + AsGridPointer p_as_grid; + BsGridPointer p_bs_grid; DsGridPointer p_ds_grid; EDataType* p_e_grid; @@ -355,29 +390,39 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { + // Note: in xdl implementation multiple AB supports one layout + // but multiple strides, so we create an array of offsets with + // the same values. + // It should be fixed later on. Once we will have a thread transfer + // more flexible. if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead / APackedSize; + static_for<0, NumATensor, 1>{}( + [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead / APackedSize; }); } else if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead * karg.StrideA; + static_for<0, NumATensor, 1>{}( + [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; }); } if constexpr(is_same_v) { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; }); } else if constexpr(is_same_v) { if constexpr(!PermuteB) { - b_k_split_offset = k_id * karg.KRead / BPackedSize; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); } else { const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = k_id * k0_offset / BPackedSize; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); } } @@ -410,8 +455,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale } } - index_t a_k_split_offset; - index_t b_k_split_offset; + std::array a_k_split_offset; + std::array b_k_split_offset; index_t scale_k_split_offset; // New member for scale matrix offset index_t c_reduce_offset; }; @@ -423,7 +468,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template + template __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, const BScaleType* p_b_scale_grid, index_t block_n_id) @@ -488,8 +533,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, DsGridPointer& p_ds_grid, EDataType* p_e_grid, const BScaleType* p_b_scale_grid, @@ -499,10 +544,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) { - const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( - problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); - const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( - problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0); const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N( @@ -542,20 +587,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale const index_t num_k_block_per_scale = GetKBlockPerScale(); - Base::template Run(p_a_grid, - p_b_grid, + TailNum>(p_as_grid, + p_bs_grid, p_ds_grid, p_e_grid, p_shared, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -575,10 +620,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __device__ static void Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg) { + // shift A matrices pointer for splitk + AsGridPointer p_as_grid_splitk; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + p_as_grid_splitk(i) = static_cast(karg.p_as_grid[i]) + + splitk_batch_offset.a_k_split_offset[i]; + }); + + // shift B matrices pointer for splitk + BsGridPointer p_bs_grid_splitk; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + p_bs_grid_splitk(i) = static_cast(karg.p_bs_grid[i]) + + splitk_batch_offset.b_k_split_offset[i]; + }); + Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, //; + splitk_batch_offset.c_reduce_offset, + p_as_grid_splitk, + p_bs_grid_splitk, + karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset, karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, p_shared, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 59d3a6a4c5..dac0c9b3b0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -61,8 +62,8 @@ template {}; static constexpr auto I7 = Number<7>{}; + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + + using LDSTypeA = + typename std::conditional<(NumATensor > 1), + ComputeTypeA, + remove_cvref_t>>::type; + using LDSTypeB = + typename std::conditional<(NumBTensor > 1), + ComputeTypeB, + remove_cvref_t>>::type; + static constexpr auto EShuffleBlockTransferScalarPerVector = CDEShuffleBlockTransferScalarPerVectors{}[I0]; @@ -136,14 +149,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base using ThisThreadBlock = ThisThreadBlock; static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) + if constexpr(is_same_v, pk_i4_t>) return 2; else return 1; }(); static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) + if constexpr(is_same_v, pk_i4_t>) return 2; else return 1; @@ -230,6 +243,31 @@ struct GridwiseGemm_wmma_cshuffle_v3_base make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); } + static constexpr auto MakeAsGridPointer() + { + return generate_tuple( + [&](auto i) { + using ADataType_ = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeBsGridPointer() + { + return generate_tuple( + [&](auto i) { + using BDataType_ = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using AsGridPointer = decltype(MakeAsGridPointer()); + using BsGridPointer = decltype(MakeBsGridPointer()); + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) { @@ -314,6 +352,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + __host__ __device__ static auto + MakeAsGridDescriptor_AK0_M_AK1(const index_t M, + const index_t MPad, + const index_t K, + const index_t KPad, + const std::array& StrideAs, + const index_t AK0) + { + return generate_tuple( + [&](auto i) { + return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0); + }, + Number{}); + } + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) { @@ -330,7 +383,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base using GemmSpecialization = tensor_operation::device::GemmSpecialization; - static_assert(!(is_same_v, pk_i4_t> && + static_assert(!(is_same_v, pk_i4_t> && GemmSpec != GemmSpecialization::Default), "pk_i4_t does not support padding"); @@ -424,6 +477,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const index_t K, + const index_t KPad, + const index_t N, + const index_t NPad, + const std::array& StrideBs, + const index_t BK0) + { + return generate_tuple( + [&](auto i) { + return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0); + }, + Number{}); + } + template __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) { @@ -557,7 +625,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // in some cases. else if constexpr(is_same::value) { - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeA) / APackedSize; constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( make_tuple( @@ -604,20 +672,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base constexpr auto KThreadRead = 64 / MPerWmma; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); constexpr auto KThreadReadPerm = (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) : KThreadRead; // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) + constexpr auto mpair = (AK1Number * MPerWmma * sizeof(LDSTypeA) > 128) ? 1 - : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 + : ((128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))) > M0 ? M0 - : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); + : 128 / (AK1Number * MPerWmma * sizeof(LDSTypeA))); constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, @@ -694,7 +762,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base else if constexpr(is_same::value) { // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeB) / BPackedSize; constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( make_tuple( @@ -738,20 +806,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base constexpr auto KThreadRead = 64 / NPerWmma; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); constexpr auto KThreadReadPerm = (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) : KThreadRead; // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) + constexpr auto npair = (BK1Number * NPerWmma * sizeof(LDSTypeB) > 128) ? 1 - : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 + : ((128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))) > N0 ? N0 - : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); + : 128 / (BK1Number * NPerWmma * sizeof(LDSTypeB))); constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, @@ -836,8 +904,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, - ADataType, - BDataType, + LDSTypeA, + LDSTypeB, ComputeTypeA, ComputeTypeB, AccDataType, @@ -1120,11 +1188,24 @@ struct GridwiseGemm_wmma_cshuffle_v3_base c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat .GetElementSpaceSize(); - return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + - b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize + + b_block_space_size_aligned * sizeof(LDSTypeB) / BPackedSize), c_block_size * sizeof(CShuffleDataType)); } + template + __device__ __forceinline__ static auto get_first_element_workaround(Type& array) + { + if constexpr(numElements > 1) + { + return array; + } + else + { + return array[I0]; + } + } + template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, DsGridPointer p_ds_grid, EDataType* p_e_grid, void* p_shared, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const AGridDesc_AK0_M_K1& as_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& bs_grid_desc_bk0_n_bk1, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& @@ -1152,10 +1233,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const index_t& num_k_block_per_scale, BScaleStruct& b_scale_struct) { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + const auto ds_grid_buf = generate_tuple( [&](auto i) { return make_dynamic_buffer( @@ -1183,66 +1274,144 @@ struct GridwiseGemm_wmma_cshuffle_v3_base constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + // workaround because v7r2 is not as general as v4r1 + auto get_a_blockwise_transfer = [&]() { + if constexpr(NumATensor > 1) + { + const auto idx_as_block_begin = generate_tuple( + [&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + return ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + AGridDesc_AK0_M_K1, + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + remove_cvref_t>, + remove_cvref_t>, + decltype(as_grid_desc_ak0_m_ak1[I0]), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + as_grid_desc_ak0_m_ak1[I0], + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; + + auto a_blockwise_copy = get_a_blockwise_transfer(); // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + // workaround because v7r2 is not as general as v4r1 + auto get_b_blockwise_transfer = [&]() { + if constexpr(NumBTensor > 1) + { + const auto idx_bs_block_begin = generate_tuple( + [&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + return ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + BGridDesc_BK0_N_K1, + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; + } + else + { + return ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + remove_cvref_t>, + remove_cvref_t>, + decltype(bs_grid_desc_bk0_n_bk1[I0]), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + bs_grid_desc_bk0_n_bk1[I0], + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + } + }; + + auto b_blockwise_copy = get_b_blockwise_transfer(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1250,12 +1419,12 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // Cast after lds auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * - sizeof(ADataType) / - APackedSize), + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(LDSTypeA) / + APackedSize), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); @@ -1267,25 +1436,26 @@ struct GridwiseGemm_wmma_cshuffle_v3_base auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / KPerBlock); - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - b_scale_struct, - num_k_block_main_loop, - num_k_block_per_scale); + blockwise_gemm_pipeline.template Run( + get_first_element_workaround(as_grid_desc_ak0_m_ak1), + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + get_first_element_workaround(as_grid_buf), + a_block_buf, + a_block_slice_copy_step, + get_first_element_workaround(bs_grid_desc_bk0_n_bk1), + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + get_first_element_workaround(bs_grid_buf), + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + b_scale_struct, + num_k_block_main_loop, + num_k_block_per_scale); // shuffle C and write out { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp index 9e524c5a23..cf3040d1ae 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -21,8 +21,7 @@ template (p_a_grid, p_b_grid, p_c_grid, @@ -67,8 +71,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) ignore = p_b_grid; ignore = p_c_grid; ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3; - ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m_n; ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; @@ -375,20 +379,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 return cblockid_to_m0_n0_block_cluster_adaptor; } - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); - using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 = - decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{})); - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, void* __restrict__ p_shared, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3& b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index aa7ce1f5b6..d2418c0913 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -164,6 +164,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using ThisThreadBlock = ThisThreadBlock; + using ElementDataTypeAB = conditional_t, float, FloatAB>; + __host__ static auto CalculateGridSize(index_t M, index_t N) { return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); @@ -236,8 +238,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // Argument struct Argument : public Problem, public tensor_operation::device::BaseArgument { - __host__ Argument(const FloatAB* p_a_grid_, - const FloatAB* p_b_grid_, + __host__ Argument(const ElementDataTypeAB* p_a_grid_, + const ElementDataTypeAB* p_b_grid_, FloatC* p_c_grid_, index_t M_, index_t N_, @@ -252,8 +254,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { } - const FloatAB* p_a_grid; - const FloatAB* p_b_grid; + const ElementDataTypeAB* p_a_grid; + const ElementDataTypeAB* p_b_grid; FloatC* p_c_grid; }; @@ -329,7 +331,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); - return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); + return (a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ElementDataTypeAB); } template < @@ -450,8 +453,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + K1, + FloatABAdjusted, + FloatABAdjusted>; return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); } @@ -471,8 +476,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N> - __device__ static void Run(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, + __device__ static void Run(const ElementDataTypeAB* p_a_grid, + const ElementDataTypeAB* p_b_grid, FloatC* p_c_grid, void* __restrict__ p_shared, const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, @@ -533,8 +538,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 Sequence, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + ElementDataTypeAB, + ElementDataTypeAB, decltype(a_grid_desc_k0_m_k1), decltype(a_block_desc_k0_m_k1), ABlockTransferSrcAccessOrder, @@ -564,8 +569,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 Sequence, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatABAdjusted, + ElementDataTypeAB, + ElementDataTypeAB, decltype(b_grid_desc_k0_n_k1), decltype(b_block_desc_k0_n_k1), BBlockTransferSrcAccessOrder, @@ -595,8 +600,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // sanity check auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - FloatABAdjusted, - FloatABAdjusted, + ElementDataTypeAB, + ElementDataTypeAB, FloatAcc, decltype(a_block_desc_k0_m_k1), decltype(b_block_desc_k0_n_k1), @@ -605,7 +610,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 MXdlPerWave, NXdlPerWave, K1, - LoopSched>(); + LoopSched, + FloatABAdjusted, + FloatABAdjusted>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -614,10 +621,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0_n_k1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index a97d9589cf..a86aa2f8ef 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -20,6 +20,7 @@ static constexpr bool is_scale_mfma_data_type() is_same_v || is_same_v; } +#ifndef CK_CODE_GEN_RTC /** * @brief Define scale data types that have hardware support for MX GEMMs */ @@ -28,6 +29,7 @@ static constexpr bool is_scale_mfma_scale_type() { return is_same_v; } +#endif /** * @brief Combination of data types that have hardware support for MX GEMMs diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 0b73f76155..c5525d5ff8 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -18,14 +18,13 @@ #define CK_USE_OCP_FP8 0 #endif -#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \ - __HIP_DEVICE_COMPILE__ +#if(defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__ #define CK_FP8_CVT_FAST_PATH 1 #else #define CK_FP8_CVT_FAST_PATH 0 #endif -#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__ +#if(defined(__gfx950__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__ #define CK_OCP_FP8_CVT_FAST_PATH 1 #else #define CK_OCP_FP8_CVT_FAST_PATH 0 @@ -35,8 +34,8 @@ namespace ck { struct f8_fnuz_t { - using data_type = unsigned char; - data_type m_data; + using data_type = unsigned char; + data_type m_data = data_type{}; __host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {} __host__ __device__ explicit constexpr f8_fnuz_t() = default; __host__ __device__ bool constexpr operator==(f8_fnuz_t other) const @@ -48,8 +47,8 @@ struct f8_fnuz_t struct bf8_fnuz_t { - using data_type = unsigned char; - data_type m_data; + using data_type = unsigned char; + data_type m_data = data_type{}; __host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {} __host__ __device__ explicit constexpr bf8_fnuz_t() = default; __host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const @@ -390,7 +389,7 @@ struct bf8_ocp_t __host__ explicit operator float() const #endif { -#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx950__) || defined(__gfx12__) return fp8_impl::cast_to_f32_from_f8(this->data); #else return fp8_impl::cast_from_f8( @@ -404,7 +403,7 @@ struct bf8_ocp_t __host__ explicit operator _Float16() const #endif { -#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx950__) || defined(__gfx12__) return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); #else return fp8_impl::cast_from_f8<_Float16, wm, we, false>( diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index be3a5cea42..7ff8e6b057 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1647,8 +1647,8 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16> __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { #if defined(__gfx94__) - reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( - reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else ignore = reg_a; ignore = reg_b; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 984bb4d862..574269b94a 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -2,7 +2,7 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once - +#include #include "ck/utility/amd_ck_fp8.hpp" #include "ck/utility/e8m0.hpp" #include "ck/utility/statically_indexed_array.hpp" @@ -325,12 +325,14 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +#ifndef CK_CODE_GEN_RTC template <> struct scalar_type { using type = e8m0_bexp_t::type; static constexpr index_t vector_size = 1; }; +#endif template <> struct scalar_type @@ -483,8 +485,10 @@ inline const char* get_type_name() return "f8"; else if constexpr(is_same_v) return "bf8"; +#ifndef CK_CODE_GEN_RTC else if constexpr(is_same_v) return "e8m0"; +#endif else if constexpr(is_same_v) return "fp32"; #if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) diff --git a/include/ck/utility/debug.hpp b/include/ck/utility/debug.hpp index 45d443ae49..1b86b33777 100644 --- a/include/ck/utility/debug.hpp +++ b/include/ck/utility/debug.hpp @@ -13,7 +13,7 @@ template struct PrintAsType; template -struct PrintAsType::value>::type> +struct PrintAsType::value>::type> { using type = float; __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast(p)); } @@ -30,7 +30,7 @@ struct PrintAsType }; template -struct PrintAsType::value>::type> +struct PrintAsType::value>::type> { using type = int; __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast(p)); } diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 27a7545a0e..084240f84b 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1294,6 +1294,7 @@ struct nnvb_data_t_selector using type = bf8_ocp_t::data_type; }; +#ifndef CK_CODE_GEN_RTC template <> struct nnvb_data_t_selector { @@ -1311,6 +1312,7 @@ struct nnvb_data_t_selector { using type = e8m0_bexp_t::type; }; +#endif template <> struct nnvb_data_t_selector @@ -2270,8 +2272,10 @@ using bf6x16_t = typename vector_type::type; using bf6x16x2_t = typename vector_type::type; using bf6x32_t = typename vector_type::type; +#ifndef CK_CODE_GEN_RTC // e8m0 using e8m0x4_bexp_t = typename vector_type::type; +#endif // pack int4 using pk_i4x2_t = typename vector_type::type; diff --git a/include/ck/utility/e8m0.hpp b/include/ck/utility/e8m0.hpp index f7d2a2f594..ac2a114593 100644 --- a/include/ck/utility/e8m0.hpp +++ b/include/ck/utility/e8m0.hpp @@ -3,6 +3,7 @@ #pragma once +#ifndef CK_CODE_GEN_RTC #include "ck/utility/type.hpp" namespace ck { @@ -78,3 +79,4 @@ __host__ __device__ inline constexpr int32_t get_exponent_value(e8m } // namespace utils } // namespace ck +#endif diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 748aa07f9e..94c2f84c8c 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -273,8 +273,8 @@ template __host__ __device__ Y cast_to_f8(X x, uint32_t rng) { // check datatypes - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = is_same::value; + constexpr bool is_float = is_same::value; static_assert(is_half || is_float, "Only half and float can be casted."); return run_cast_to_f8(x, rng); @@ -284,8 +284,8 @@ template __host__ __device__ Y cast_from_f8(X x) { // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = is_same::value; + constexpr bool is_float = is_same::value; static_assert(is_half || is_float, "only half and float are supported."); return run_cast_from_f8(x); diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 993b70a3fb..7227cee754 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -10,10 +10,6 @@ #include "type.hpp" #include "tuple.hpp" -#ifdef CK_CODE_GEN_RTC -#define INT32_MAX 2147483647 -#endif - namespace ck { // magic number division diff --git a/include/ck/utility/numeric_limits.hpp b/include/ck/utility/numeric_limits.hpp index e59b7eceaf..b8d6280acc 100644 --- a/include/ck/utility/numeric_limits.hpp +++ b/include/ck/utility/numeric_limits.hpp @@ -522,8 +522,6 @@ struct NumericLimits } }; -#endif - template <> struct NumericLimits { @@ -551,5 +549,6 @@ struct NumericLimits return e8m0_bexp_t(binary_142); } }; +#endif } // namespace ck diff --git a/include/ck/utility/numeric_utils.hpp b/include/ck/utility/numeric_utils.hpp index 726f667518..399bc0c3e8 100644 --- a/include/ck/utility/numeric_utils.hpp +++ b/include/ck/utility/numeric_utils.hpp @@ -10,6 +10,7 @@ struct NumericUtils { }; +#ifndef CK_CODE_GEN_RTC template <> struct NumericUtils { @@ -24,6 +25,7 @@ struct NumericUtils using bitwise_type = uint8_t; }; +#endif template <> struct NumericUtils diff --git a/include/ck/utility/random_gen.hpp b/include/ck/utility/random_gen.hpp index 2ff46457fc..dd2662b6d9 100644 --- a/include/ck/utility/random_gen.hpp +++ b/include/ck/utility/random_gen.hpp @@ -15,7 +15,7 @@ namespace ck { // Pseudo random number generator // version for fp32 -template {}, bool> = false> +template {}, bool> = false> __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) { uint32_t x = bit_cast(val); @@ -31,7 +31,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = } // version for fp16 -template {}, bool> = false> +template {}, bool> = false> __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) { uint16_t x = bit_cast(val); @@ -48,7 +48,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = // return 0 if data is not fp16 or fp32 template {} || std::is_same<_Float16, T>{}), bool> = false> + ck::enable_if_t{} || is_same<_Float16, T>{}), bool> = false> __host__ __device__ uint32_t prand_generator(int id, T val, uint32_t seed = seed_t) { ck::ignore = id; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 66d760c2b3..701b2686c7 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -988,7 +988,7 @@ inline __host__ __device__ float2_t type_convert(f8x2_ocp_ #if CK_OCP_FP8_CVT_FAST_PATH // __builtin_amdgcn_cvt_pk_f32_fp8 can produce incorrect results due to a compiler issue. // TODO: Enable when SWDEV-532959 is fixed. -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx12__) return float2_t{__builtin_amdgcn_cvt_f32_fp8(bit_cast(x), 0), __builtin_amdgcn_cvt_f32_fp8(bit_cast(x), 1)}; #else @@ -1131,7 +1131,7 @@ inline __host__ __device__ float2_t type_convert(bf8x2_oc #if CK_OCP_FP8_CVT_FAST_PATH // __builtin_amdgcn_cvt_pk_f32_bf8 can produce incorrect results due to a compiler issue. // TODO: Enable when SWDEV-532959 is fixed. -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx12__) return float2_t{__builtin_amdgcn_cvt_f32_bf8(bit_cast(x), 0), __builtin_amdgcn_cvt_f32_bf8(bit_cast(x), 1)}; #else diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 7a9c017eb2..7bc5ca5df8 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -2788,7 +2788,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer& src_thread_ } #if defined(__gfx950__) -template +template __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { #define __LDS_ADDR __attribute__((address_space(3))) @@ -2829,6 +2829,60 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) } #endif +// amd_wave_read_first_lane is the SGPR function from AMD GPU device to load 1 or a series of the +// memory to the SGPR registers. +__device__ inline uint32_t amd_wave_read_first_lane(uint16_t v) +{ + return __builtin_amdgcn_readfirstlane(static_cast(v)); +} + +__device__ inline uint32_t amd_wave_read_first_lane(uint8_t v) +{ + return __builtin_amdgcn_readfirstlane(static_cast(v)); +} + +__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +__device__ inline int32_t amd_wave_read_first_lane(int32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +template , int> = 0> +__device__ inline auto amd_wave_read_first_lane(const Object& obj) +{ + constexpr size_t ObjectSize = sizeof(Object); + constexpr size_t SGPR_size = 4; + constexpr size_t NumFull = ObjectSize / SGPR_size; + constexpr size_t Tail = ObjectSize % SGPR_size; + + const unsigned char* src = reinterpret_cast(&obj); + alignas(Object) unsigned char dst[ObjectSize]; + + static_for<0, NumFull, 1>{}([&](auto Ic) { + constexpr size_t offset = Ic * SGPR_size; + uint32_t read_src; + __builtin_memcpy(&read_src, src + offset, SGPR_size); + read_src = __builtin_amdgcn_readfirstlane(read_src); + __builtin_memcpy(dst + offset, &read_src, SGPR_size); + }); + + if constexpr(Tail != 0) + { + constexpr size_t offset = NumFull * SGPR_size; + uint32_t tail_loc = 0; + __builtin_memcpy(&tail_loc, src + offset, Tail); + tail_loc = __builtin_amdgcn_readfirstlane(tail_loc); + __builtin_memcpy(dst + offset, &tail_loc, Tail); + } + Object out; + __builtin_memcpy(&out, dst, ObjectSize); + return out; +} + } // namespace ck_tile #endif // !CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 5c7ffefc6a..ce5a8075df 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -2570,6 +2570,60 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer& src_thread_ #endif } +// amd_wave_read_first_lane is the SGPR function from AMD GPU device to load 1 or a series of the +// memory to the SGPR registers. +__device__ inline uint32_t amd_wave_read_first_lane(uint16_t v) +{ + return __builtin_amdgcn_readfirstlane(static_cast(v)); +} + +__device__ inline uint32_t amd_wave_read_first_lane(uint8_t v) +{ + return __builtin_amdgcn_readfirstlane(static_cast(v)); +} + +__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +__device__ inline int32_t amd_wave_read_first_lane(int32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +template , int> = 0> +__device__ inline auto amd_wave_read_first_lane(const Object& obj) +{ + constexpr size_t ObjectSize = sizeof(Object); + constexpr size_t SGPR_size = 4; + constexpr size_t NumFull = ObjectSize / SGPR_size; + constexpr size_t Tail = ObjectSize % SGPR_size; + + const unsigned char* src = reinterpret_cast(&obj); + alignas(Object) unsigned char dst[ObjectSize]; + + static_for<0, NumFull, 1>{}([&](auto Ic) { + constexpr size_t offset = Ic * SGPR_size; + uint32_t read_src; + __builtin_memcpy(&read_src, src + offset, SGPR_size); + read_src = __builtin_amdgcn_readfirstlane(read_src); + __builtin_memcpy(dst + offset, &read_src, SGPR_size); + }); + + if constexpr(Tail != 0) + { + constexpr size_t offset = NumFull * SGPR_size; + uint32_t tail_loc = 0; + __builtin_memcpy(&tail_loc, src + offset, Tail); + tail_loc = __builtin_amdgcn_readfirstlane(tail_loc); + __builtin_memcpy(dst + offset, &tail_loc, Tail); + } + Object out; + __builtin_memcpy(&out, dst, ObjectSize); + return out; +} + template CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, const index_t global_offset, @@ -2585,9 +2639,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; #if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM - T* lds_ptr = lds_base_ptr + lds_offset; - auto const lds_ptr_sgpr = - __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); + T* lds_ptr = lds_base_ptr + lds_offset; + auto const lds_ptr_sgpr = amd_wave_read_first_lane((reinterpret_cast(lds_ptr))); asm volatile("s_mov_b32 m0, %0; \n\t" "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), "v"(global_offset_bytes), @@ -2619,7 +2672,7 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, } #if defined(__gfx950__) -template +template __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr) { #define __LDS_ADDR __attribute__((address_space(3))) diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 42f2390cde..28ded5439a 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -9,6 +9,8 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" +#include "ck_tile/core/arch/amd_buffer_addressing.hpp" #include "ck_tile/core/utility/ignore.hpp" #define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111 @@ -104,7 +106,7 @@ CK_TILE_DEVICE index_t get_warp_id(bool_constant = {}) const index_t warp_id = threadIdx.x / get_warp_size(); if constexpr(ReturnSgpr) { - return __builtin_amdgcn_readfirstlane(warp_id); + return amd_wave_read_first_lane(warp_id); } else { diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index d1e770ef42..3b747dae84 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -875,10 +875,9 @@ struct buffer_view, t_per_x, addr_space>( - p_data_ + i + linear_offset); + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + return amd_transpose_load_to_vgpr, t_per_x>(p_data_ + i + + linear_offset); #else return X{numeric>::zero()}; #endif diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index a3620453b4..2e9ab0f5c6 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -158,7 +158,4 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window -concept IsLoadableTile = requires { load_tile(std::declval()); }; - } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 284efd5d70..d29afa2d98 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -231,7 +231,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors) template CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors) { -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) // This API is designed to use the _pk_ serious of function constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index b45106487e..2db5d719c0 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -402,7 +402,7 @@ struct tile_window_with_static_distribution const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant{}); m0_set_with_memory( - __builtin_amdgcn_readfirstlane(m0_init_value)); // This should be wave independent + amd_wave_read_first_lane(m0_init_value)); // This should be wave independent using Traits = typename Base::Traits; diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp index ecd4e81b22..052ee4ae62 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp @@ -92,13 +92,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass static constexpr index_t Block_N = Problem::BlockShape::Block_N; index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + amd_wave_read_first_lane(integer_divide_ceil(row_size, Block_N)); using XTensorType = decltype(cast_tile(load_tile(a_window))); auto square_sum = block_reduce2d.template MakeYBlockTile(); set_tile(square_sum, reduce_square_sum_func.GetIdentityValue()); - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { const auto a = load_tile(a_window); const auto b = load_tile(b_window); @@ -149,7 +149,7 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass if constexpr(kSaveX) __syncthreads(); - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { auto x = [&]() { if constexpr(kSaveX) @@ -226,7 +226,7 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass } move_tile_window(gamma_window, {Block_N}); - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { auto x = [&]() { if constexpr(kSaveX) diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index b0f48f6c5b..c99571562d 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -84,9 +84,9 @@ struct BatchedTransposeKernel static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput; static constexpr ck_tile::index_t VectorStrideOutput = 1; - const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); - const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock); - const auto offset = __builtin_amdgcn_readfirstlane(blockIdx.z * kargs.height * kargs.width); + const auto iM = amd_wave_read_first_lane(blockIdx.x * kMPerBlock); + const auto iN = amd_wave_read_first_lane(blockIdx.y * kNPerBlock); + const auto offset = amd_wave_read_first_lane(blockIdx.z * kargs.height * kargs.width); const auto x_m_n = [&]() { const auto x_dram_naive = make_naive_tensor_view( diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 9cba02464f..aacf1602ff 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -9,28 +9,9 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include +#include namespace ck_tile { - -template -concept HasDataType = requires -{ - typename T::DataType; -}; - -template -struct GetDataType -{ - using type = float; -}; - -template -requires HasDataType -struct GetDataType -{ - using type = typename T::DataType; // Use T::ScaleN::DataType -}; - template + template CK_TILE_DEVICE void scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window) { @@ -339,7 +320,7 @@ struct CShuffleEpilogue constexpr index_t num_access = SFC::get_num_of_access(); if constexpr(iAccess != num_access - 1) { - constexpr auto step = SFC::get_forward_step(iAccess); + constexpr auto step = SFC::get_forward_step(number{}); move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})}); move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})}); @@ -347,10 +328,10 @@ struct CShuffleEpilogue } } - template + template CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile) { - constexpr auto idx_y_start = SFC::get_index(iAccess); + constexpr auto idx_y_start = SFC::get_index(number{}); constexpr auto mIter = number{}) / (MPerIterationShuffle)>{}; constexpr auto nIter = number{}) / (NPerIterationShuffle)>{}; @@ -405,13 +386,13 @@ struct CShuffleEpilogue /** * @brief Move both the output and D tensors windows for the next access. */ - template + template CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows) { constexpr index_t num_access = SFC::get_num_of_access(); if constexpr(iAccess != num_access - 1) { - constexpr auto step = SFC::get_forward_step(iAccess); + constexpr auto step = SFC::get_forward_step(number{}); // move the output dram window move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})}); @@ -428,6 +409,18 @@ struct CShuffleEpilogue { }; + template + struct ScaleDataType + { + using DataType = float; + }; + + template + struct ScaleDataType> + { + using DataType = typename T::DataType; + }; + template && std::is_same_v; // Tiles to hold row/col scales when present - using SMType = typename GetDataType>::type; - using SNType = typename GetDataType>::type; + using SMType = typename ScaleDataType::DataType; + using SNType = typename ScaleDataType::DataType; auto sm_tile = make_static_distributed_tensor(dram_tile_distribution); auto sn_tile = make_static_distributed_tensor(dram_tile_distribution); - // Build windows only if scales are provided + // Build windows only if non-scalar scales are provided auto scale_m_window = [&]() { if constexpr(has_scales && !has_scalar_scales) { - static_assert( - IsLoadableTile, - "ScaleM must be a loadable tile"); return make_tile_window(scale_m, dram_tile_distribution); } else @@ -503,9 +493,6 @@ struct CShuffleEpilogue auto scale_n_window = [&]() { if constexpr(has_scales && !has_scalar_scales) { - static_assert( - IsLoadableTile, - "ScaleN must be a loadable tile"); return make_tile_window(scale_n, dram_tile_distribution); } else @@ -520,8 +507,8 @@ struct CShuffleEpilogue merge_sequences(sequence{}, c_warp_y_index_zeros), merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); - // If scales provided, load them with identical distribution - if constexpr(has_scales && IsLoadableTile && IsLoadableTile) + // If non-scalar scales provided, load them with identical distribution + if constexpr(has_scales && !has_scalar_scales) { sm_tile = load_tile(scale_m_window); // row scales in permuted layout sn_tile = load_tile(scale_n_window); // col scales in permuted layout @@ -540,7 +527,7 @@ struct CShuffleEpilogue { v = static_cast(v * scale_m * scale_n); } - else if constexpr(has_scales) + else if constexpr(has_scales && !has_scalar_scales) { // same linear index mapping on the permuted distribution const auto s_m = static_cast(sm_tile.get_thread_buffer()[out_idx]); @@ -641,9 +628,6 @@ struct CShuffleEpilogue } else if constexpr(has_scales) { - static_assert( - IsLoadableTile, - "ScaleM must be a loadable tile"); return make_tile_window(scale_m, lds_tile.get_tile_distribution()); } else @@ -658,9 +642,6 @@ struct CShuffleEpilogue } else if constexpr(has_scales) { - static_assert( - IsLoadableTile, - "ScaleN must be a loadable tile"); return make_tile_window(scale_n, lds_tile.get_tile_distribution()); } else diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index a924279d52..ab0b310510 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -598,8 +598,8 @@ struct FlatmmKernel CK_TILE_DEVICE void operator()(KernelArgs kargs) const { const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); // options diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index fcd512056d..56865498c0 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -707,8 +707,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); long_index_t batch_offset_q = 0; long_index_t batch_offset_bias = 0; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index a196807b83..b2b00a07e4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -60,12 +60,12 @@ struct FmhaBwdDQDKDVKernel using VGradDataType = ck_tile::remove_cvref_t; using BiasGradDataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; - static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; - static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; - using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr index_t kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; + using FmhaMask = ck_tile::remove_cvref_t; using FmhaDropout = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kHasDropout = FmhaDropout::IsDropout; @@ -100,8 +100,8 @@ struct FmhaBwdDQDKDVKernel #define _TS_ std::to_string auto pn = [&] () { std::string n; - if (kPadHeadDimQ) n += "d"; - if (kPadHeadDimV) n += "dv"; + if (kPadHeadDimQ) n += "d" + _TS_(kPadHeadDimQ); + if (kPadHeadDimV) n += "dv"+ _TS_(kPadHeadDimV); return n.empty() ? n : std::string("p") + n; }(); return _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + @@ -117,7 +117,7 @@ struct FmhaBwdDQDKDVKernel ("maxq" + _TS_(kMaxSeqLenQ)) + (pn.empty() ? "_npad" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) + + (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? gwt0::at(ck_tile::number<0>{}) == 16? "_dropout_wg16":"_dropout_wg32" : "_ndropout" ) + (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload"); #undef _SS_ #undef _TS_ @@ -690,7 +690,7 @@ struct FmhaBwdDQDKDVKernel // divide problem const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex(); - const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0); + const index_t i_n0 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN0); long_index_t batch_offset_q = 0; long_index_t batch_offset_k = 0; @@ -815,7 +815,7 @@ struct FmhaBwdDQDKDVKernel const auto q_dram = pad_tensor_view( q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); const auto k_dram_naive = make_naive_tensor_view( k_ptr, @@ -826,7 +826,7 @@ struct FmhaBwdDQDKDVKernel const auto k_dram = pad_tensor_view( k_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); const auto v_dram = [&]() { const auto v_dram_naive = make_naive_tensor_view( @@ -838,7 +838,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( v_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); }(); // lse and d should be fine to read unpaded data as they are not on the reduction dimension @@ -857,7 +857,7 @@ struct FmhaBwdDQDKDVKernel const auto do_dram = pad_tensor_view( do_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); auto q_dram_window = make_tile_window( q_dram, @@ -905,7 +905,7 @@ struct FmhaBwdDQDKDVKernel const auto dq_acc_dram = pad_tensor_view( dq_acc_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); return make_tile_window( dq_acc_dram, make_tuple(number{}, number{}), @@ -1089,7 +1089,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dk_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); }(); auto dv_dram = [&]() { @@ -1103,7 +1103,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dv_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence 0)>{}); }(); auto dk_dram_window = make_tile_window( @@ -1338,7 +1338,7 @@ struct FmhaBwdOGradDotOKernel // divide problem const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex(); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); + const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0); long_index_t batch_offset_o = 0; long_index_t batch_offset_do = 0; @@ -1618,7 +1618,7 @@ struct FmhaBwdConvertQGradKernel // divide problem const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex(); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); + const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0); long_index_t batch_offset_dq = 0; long_index_t batch_offset_dq_acc = 0; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 66f51459af..a82d121d62 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -262,8 +262,8 @@ struct FmhaFwdAppendKVKernel // divide problem const auto [i_tile, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0); - const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0); + const index_t i_m0 = amd_wave_read_first_lane(i_tile * FmhaPipeline::kM0); + const index_t i_n0 = amd_wave_read_first_lane(i_tile * FmhaPipeline::kN0); const index_t i_cache_batch = [&, i_batch_ = i_batch] { if constexpr(kIsPagedKV) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 58fdad149a..dafe99febe 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -72,12 +72,14 @@ struct FmhaFwdKernel static constexpr std::string_view kPipelineName = FmhaPipeline::name; // clang-format off - template struct t2s; + template struct t2s; template <> struct t2s { static constexpr const char * name = "fp32"; }; template <> struct t2s { static constexpr const char * name = "fp16"; }; template <> struct t2s { static constexpr const char * name = "bf16"; }; template <> struct t2s { static constexpr const char * name = "fp8"; }; template <> struct t2s { static constexpr const char * name = "bf8"; }; + template <> struct t2s { static constexpr const char * name = "fp8bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8fp32"; }; // clang-format on CK_TILE_HOST static std::string GetName() @@ -99,7 +101,7 @@ struct FmhaFwdKernel if (kPadHeadDimV) n += "dv"; return n.empty() ? n : std::string("p") + n; }(); return - _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_" "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + @@ -291,6 +293,11 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_o; + + // Optional cumulative sequence length pointers for batch mode + // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // cumulative, length without PAD }; struct FmhaFwdGroupModeKargs @@ -310,6 +317,11 @@ struct FmhaFwdKernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; + + // Optional cumulative padded sequence starts (including PAD tokens) + // Used solely to compute memory offsets when sequences are physically padded. + const int32_t* seqstart_padded_q_ptr = nullptr; + const int32_t* seqstart_padded_k_ptr = nullptr; }; using Kargs = std::conditional_t; @@ -366,7 +378,9 @@ struct FmhaFwdKernel float p_drop, bool s_randval, std::variant, std::pair> - drop_seed_offset) + drop_seed_offset, + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -457,6 +471,8 @@ struct FmhaFwdKernel kargs.init_logits_soft_cap(logits_soft_cap); } + kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; + kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; return kargs; } @@ -505,7 +521,9 @@ struct FmhaFwdKernel ck_tile::index_t mask_type, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + const std::tuple& drop_seed_offset, + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -550,7 +568,9 @@ struct FmhaFwdKernel mask_type, p_drop, s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + cu_seqlen_q_ptr, + cu_seqlen_kv_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -598,7 +618,9 @@ struct FmhaFwdKernel ck_tile::index_t mask_type, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + const std::tuple& drop_seed_offset, + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -643,7 +665,9 @@ struct FmhaFwdKernel mask_type, p_drop, s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + cu_seqlen_q_ptr, + cu_seqlen_kv_ptr); } template @@ -686,7 +710,9 @@ struct FmhaFwdKernel float p_drop, bool s_randval, std::variant, std::pair> - drop_seed_offset) + drop_seed_offset, + const void* seqstart_padded_q_ptr = nullptr, + const void* seqstart_padded_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -778,6 +804,8 @@ struct FmhaFwdKernel kargs.min_seqlen_q = min_seqlen_q; } + kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); + kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); return kargs; } @@ -821,7 +849,9 @@ struct FmhaFwdKernel ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + const std::tuple& drop_seed_offset, + const void* seqstart_padded_q_ptr = nullptr, + const void* seqstart_padded_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -861,7 +891,9 @@ struct FmhaFwdKernel min_seqlen_q, p_drop, s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + seqstart_padded_q_ptr, + seqstart_padded_k_ptr); } // std::variant<> can't take in a list initializer, overload for backward compatibility @@ -904,7 +936,9 @@ struct FmhaFwdKernel ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, - const std::tuple& drop_seed_offset) + const std::tuple& drop_seed_offset, + const void* seqstart_padded_q_ptr = nullptr, + const void* seqstart_padded_k_ptr = nullptr) { return MakeKargsImpl( q_ptr, @@ -944,7 +978,9 @@ struct FmhaFwdKernel min_seqlen_q, p_drop, s_randval, - std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset))); + std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + seqstart_padded_q_ptr, + seqstart_padded_k_ptr); } CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, @@ -1060,8 +1096,8 @@ struct FmhaFwdKernel // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); long_index_t batch_offset_q = 0; long_index_t batch_offset_k = 0; @@ -1073,35 +1109,44 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { - // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + // logical and physical (padded) starts + const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; + const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr + ? kargs.seqstart_padded_q_ptr[i_batch] + : query_start_unpadded; + const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr + ? kargs.seqstart_padded_k_ptr[i_batch] + : key_start_unpadded; + + // DRAM base offsets use physical padded starts + batch_offset_q = query_start_padded * kargs.stride_q; + batch_offset_k = key_start_padded * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start * kargs.stride_v; + batch_offset_v = key_start_padded * kargs.stride_v; } else { - batch_offset_v = key_start; + batch_offset_v = key_start_padded; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start * kargs.stride_bias; + batch_offset_bias = query_start_padded * kargs.stride_bias; } if constexpr(kStoreLSE) { - batch_offset_lse = query_start; + // LSE stays indexed by unpadded starts + batch_offset_lse = query_start_unpadded; } if constexpr(kHasDropout) { - batch_offset_randval = query_start * kargs.stride_randval; + batch_offset_randval = query_start_padded * kargs.stride_randval; } - batch_offset_o = query_start * kargs.stride_o; + batch_offset_o = query_start_padded * kargs.stride_o; - // get real # queries & # keys under group mode + // real logical lengths (exclude PAD) const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0]; @@ -1113,8 +1158,7 @@ struct FmhaFwdKernel } } - // # of required blocks is different in each groups, terminate unnecessary blocks - // earlier + // terminate unnecessary blocks earlier if(kargs.seqlen_q <= i_m0) { return; @@ -1150,6 +1194,18 @@ struct FmhaFwdKernel static_cast(i_batch) * kargs.batch_stride_randval; } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + + // If cumulative seqlen pointers are provided, override per-batch effective lengths + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + if(kargs.cu_seqlen_kv_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + } } // for simplicity, batch stride we just modify the pointer @@ -1548,26 +1604,35 @@ struct FmhaFwdKernel if constexpr(kIsGroupMode) { // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; + const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr + ? kargs.seqstart_padded_q_ptr[i_batch] + : query_start_unpadded; + const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr + ? kargs.seqstart_padded_k_ptr[i_batch] + : key_start_unpadded; + + batch_offset_q = query_start_padded * kargs.stride_q; + batch_offset_k = key_start_padded * kargs.stride_k; if constexpr(std::is_same_v) { - batch_offset_v = key_start * kargs.stride_v; + batch_offset_v = key_start_padded * kargs.stride_v; } else { - batch_offset_v = key_start; + // col-major V: offset along seqlen dimension is scalar index + batch_offset_v = key_start_padded; } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start * kargs.stride_bias; + batch_offset_bias = query_start_padded * kargs.stride_bias; } - batch_offset_lse = query_start; - batch_offset_o = query_start * kargs.stride_o; + // LSE layout is [nhead, total_seqlen], index by unpadded start + batch_offset_lse = query_start_unpadded; + batch_offset_o = query_start_padded * kargs.stride_o; // get real # queries & # keys under group mode kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch]; @@ -1605,6 +1670,18 @@ struct FmhaFwdKernel batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } + + // If cumulative seqlen pointers are provided, override per-batch effective lengths + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + if(kargs.cu_seqlen_kv_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + } } // for simplicity, batch stride we just modify the pointer @@ -1767,6 +1844,9 @@ struct FmhaFwdKernel make_tuple(number{}, number{}), sequence{}); + constexpr auto kDramTileK = + FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0; + #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD constexpr index_t LDSLayerSize = 256 / sizeof(KDataType); constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim); @@ -1835,32 +1915,36 @@ struct FmhaFwdKernel { const auto k_dram_unmerged = transform_tensor_view( k_dram_pad, - make_tuple( - make_pass_through_transform(height), - make_unmerge_transform(make_tuple( - number{}, - number{}))), + make_tuple(make_pass_through_transform(height), + make_unmerge_transform( + make_tuple(number{}, + number{}, + number{}))), make_tuple(sequence<0>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{})); + make_tuple(sequence<0>{}, sequence<1, 2, 3>{})); const auto k_dram_permuted = transform_tensor_view( k_dram_unmerged, make_tuple( make_xor_transform(make_tuple( - height, - number{})), + height, number{})), + make_pass_through_transform( + number{}), make_pass_through_transform(number{})), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0, 1>{}, sequence<2>{})); + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); return transform_tensor_view( k_dram_permuted, - make_tuple( - make_pass_through_transform(height), - make_merge_transform_v3_division_mod(make_tuple( - number{}, - number{}))), - make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(make_pass_through_transform(height), + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1, 2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); } }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp index 58ef6ba87e..62ac70db92 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp @@ -880,8 +880,8 @@ struct FmhaFwdPagedKVKernel // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); long_index_t batch_offset_q = 0; long_index_t batch_offset_k = 0; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index cf819c4b8d..a6fc0f1471 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -281,8 +281,8 @@ struct FmhaFwdSplitKVCombineKernel // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_o_acc = 0; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 9293c97a31..80de65ead4 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -589,8 +589,8 @@ struct FmhaFwdSplitKVKernel // divide problem const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); long_index_t batch_offset_q = 0; long_index_t batch_offset_k = 0; // unused for paged-kvcache diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index c5e5745817..e9115b14df 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -100,6 +100,11 @@ struct FmhaFwdV3Kernel ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_o; + + // Optional cumulative sequence length pointers for batch mode + // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding. + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; struct FmhaFwdGroupModeKargs @@ -110,6 +115,11 @@ struct FmhaFwdV3Kernel const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; const int32_t* seqlen_k_ptr; + + // Optional cumulative padded sequence starts (including PAD tokens) + // Used solely to compute memory offsets when sequences are physically padded. + const int32_t* seqstart_padded_q_ptr = nullptr; // [batch+1] + const int32_t* seqstart_padded_k_ptr = nullptr; // [batch+1] }; using Kargs = std::conditional_t; @@ -145,7 +155,9 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, - ck_tile::index_t remap_opt) + ck_tile::index_t remap_opt, + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -187,6 +199,8 @@ struct FmhaFwdV3Kernel kargs.batch_stride_lse = batch_stride_lse; } + kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; + kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; return kargs; } @@ -217,7 +231,9 @@ struct FmhaFwdV3Kernel ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, - ck_tile::index_t remap_opt) + ck_tile::index_t remap_opt, + const void* seqstart_padded_q_ptr = nullptr, + const void* seqstart_padded_k_ptr = nullptr) { Kargs kargs{{q_ptr, k_ptr, @@ -257,6 +273,8 @@ struct FmhaFwdV3Kernel kargs.nhead_stride_lse = nhead_stride_lse; } + kargs.seqstart_padded_q_ptr = reinterpret_cast(seqstart_padded_q_ptr); + kargs.seqstart_padded_k_ptr = reinterpret_cast(seqstart_padded_k_ptr); return kargs; } @@ -361,8 +379,8 @@ struct FmhaFwdV3Kernel // divide problem const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); - const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); - const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1); long_index_t batch_offset_q = 0; long_index_t batch_offset_k = 0; @@ -373,18 +391,26 @@ struct FmhaFwdV3Kernel if constexpr(kIsGroupMode) { // get starting offset for each batch - const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; + const long_index_t query_start_unpadded = kargs.seqstart_q_ptr[i_batch]; + const long_index_t key_start_unpadded = kargs.seqstart_k_ptr[i_batch]; - batch_offset_q = query_start * kargs.stride_q; - batch_offset_k = key_start * kargs.stride_k; - batch_offset_v = key_start * kargs.stride_v; + const long_index_t query_start_padded = kargs.seqstart_padded_q_ptr + ? kargs.seqstart_padded_q_ptr[i_batch] + : query_start_unpadded; + const long_index_t key_start_padded = kargs.seqstart_padded_k_ptr + ? kargs.seqstart_padded_k_ptr[i_batch] + : key_start_unpadded; + + batch_offset_q = query_start_padded * kargs.stride_q; + batch_offset_k = key_start_padded * kargs.stride_k; + batch_offset_v = key_start_padded * kargs.stride_v; if constexpr(kStoreLSE) { - batch_offset_lse = query_start; + // LSE layout is [nhead, total_seqlen], index by unpadded start + batch_offset_lse = query_start_unpadded; } - batch_offset_o = query_start * kargs.stride_o; + batch_offset_o = query_start_padded * kargs.stride_o; // get real # queries & # keys under group mode const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; @@ -417,6 +443,18 @@ struct FmhaFwdV3Kernel batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; } batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + + // If cumulative seqlen pointers are provided, override per-batch effective lengths + if(kargs.cu_seqlen_q_ptr != nullptr) + { + kargs.seqlen_q = + kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch]; + } + if(kargs.cu_seqlen_kv_ptr != nullptr) + { + kargs.seqlen_k = + kargs.cu_seqlen_kv_ptr[i_batch + 1] - kargs.cu_seqlen_kv_ptr[i_batch]; + } } // for simplicity, batch stride we just modify the pointer diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 5e63fb714a..ea024a0257 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; @@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad(); static constexpr index_t kAlignmentQGrad = 1; static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad(); static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "kr_ktr_vr"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index c402eaeac4..6393f227a2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; @@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad(); static constexpr index_t kAlignmentQGrad = 1; static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad(); static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "kr_ktr_vr_iglp"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp index c3e84df934..abe024ced1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp @@ -14,7 +14,8 @@ namespace ck_tile { template class BlockFmhaBwdDQDKDVPipelineSelector { - static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV; + static constexpr bool has_dpad1 = + Problem::Traits::kPadHeadDimQ == 1 || Problem::Traits::kPadHeadDimV == 1; static constexpr bool is_decode = Problem::BlockFmhaShape::kMaxSeqLenQ > 0; public: @@ -24,7 +25,7 @@ class BlockFmhaBwdDQDKDVPipelineSelector std::conditional_t, BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR>, - std::conditional_t, BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP>>; using type = std::conditional_t, // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 41cb4fc306..5cdb4fe1d7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; @@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad(); static constexpr index_t kAlignmentQGrad = 1; static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad(); static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "trload_kr_ktr_vr"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 6d90429407..3d5bfcc76a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -51,8 +51,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV; static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kIsDeterministic = Problem::kIsDeterministic; @@ -62,18 +62,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentK(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK(); static constexpr index_t kAlignmentV = - kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentOGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad(); static constexpr index_t kAlignmentQGrad = 1; static constexpr index_t kAlignmentKGrad = - kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad(); + kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad(); static constexpr index_t kAlignmentVGrad = - kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad(); + kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad(); static constexpr index_t kAlignmentBias = 1; static constexpr const char* name = "trload_kr_ktr_vr"; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index ad9e2959f5..5eac387a66 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -408,8 +408,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy sequence<1, 2>, sequence<2, 1>>{}); - if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == - kNPerBlock * kKPerBlock) + if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2 { return dstr; } @@ -457,8 +456,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence<2, 0>>, sequence<1, 2>, // N0 K1 sequence<0, 1>>{}); - if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == - kNPerBlock * kKPerBlock) + if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2 { return dstr; } @@ -507,8 +505,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy sequence<1, 2>, sequence<2, 1>>{}); - if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == - kMPerBlock * kKPerBlock) + if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2 { return dstr; } @@ -558,8 +555,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy sequence<1, 2>, sequence<2, 1>>{}); - if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == - kMPerBlock * kKPerBlock) + if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2 { return dstr; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index 99718a187f..38aff07093 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -57,13 +57,11 @@ struct BlockFmhaBwdPipelineProblem static constexpr bool kUseTrLoad = kUseTrLoad_; // attributes from traits - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr auto BiasEnum = Traits::BiasEnum; - static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; - static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ"); - static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ"); + static constexpr index_t kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr index_t kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; template ; static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once static_assert(kQLoadOnce == Policy::QLoadOnce); + static constexpr bool kKLoadOnce = BlockFmhaShape::kM0 >= 64; static constexpr index_t kBlockSize = Problem::kBlockSize; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index ff1f31edc8..dccb41ba44 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -813,7 +813,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, @@ -903,7 +904,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, tuple, sequence>, diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index cd3893f5cf..59267fa3b1 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -37,6 +37,23 @@ struct TileFmhaTraits static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_; }; +template +struct TileFmhaBwdTraits +{ + static constexpr index_t kPadHeadDimQ = kPadHeadDimQ_; + static constexpr index_t kPadHeadDimV = kPadHeadDimV_; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; + + static_assert(kPadHeadDimQ == 0 || kPadHeadDimQ == 8 || kPadHeadDimQ == 1); + static_assert(kPadHeadDimV == 0 || kPadHeadDimV == 8 || kPadHeadDimV == 1); +}; + template (kargs.num_sorted_tiles_ptr)); num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0; @@ -261,7 +261,7 @@ struct FusedMoeGemmKernel { // allocate LDS // __shared__ char smem_ptr[GetSmemSize()]; - IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane( + IndexDataType num_sorted_tiles = amd_wave_read_first_lane( *reinterpret_cast(kargs.num_sorted_tiles_ptr)); constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2; @@ -283,14 +283,14 @@ struct FusedMoeGemmKernel return; const IndexDataType expert_id = - __builtin_amdgcn_readfirstlane(reinterpret_cast( + amd_wave_read_first_lane(reinterpret_cast( kargs.sorted_expert_ids_ptr)[sorted_tile_id]); // index along intermediate_size // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id * // BlockShape::Block_N0); index_t interm_idx_nr = - __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0); + amd_wave_read_first_lane(intermediate_tile_id * BlockShape::Block_Nr0); const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col] const auto sorted_token_id = diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index faeb5cf6b3..28416ec538 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -756,7 +756,7 @@ struct MoeSortingKernel void* smem) const { const index_t tid = static_cast(threadIdx.x); - const index_t wid = __builtin_amdgcn_readfirstlane(tid / get_warp_size()); + const index_t wid = amd_wave_read_first_lane(tid / get_warp_size()); const index_t lid = __lane_id(); constexpr index_t block_size = 256; // blockDim.x; const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor; diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp index 38410721ae..d19f0894b9 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp @@ -184,17 +184,17 @@ struct FusedMoeGemmPipeline_FlatmmUk index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1; index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1; - const IndexDataType expert_id = __builtin_amdgcn_readfirstlane( + const IndexDataType expert_id = amd_wave_read_first_lane( reinterpret_cast(kargs.sorted_expert_ids_ptr)[sorted_tile_id]); index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size; index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size; // nr*kr*w - index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane( + index_t interm_idx_nr0 = amd_wave_read_first_lane( intermediate_tile_id * BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W) - index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane( + index_t interm_idx_kr1 = amd_wave_read_first_lane( intermediate_tile_id * BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W) diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 588d903b25..6f9d53467f 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -169,27 +169,27 @@ struct BatchedGemmKernel CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const { const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - const auto i_batch = __builtin_amdgcn_readfirstlane(blockIdx.y); - const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z); + const auto i_batch = amd_wave_read_first_lane(blockIdx.y); + const auto i_splitk = amd_wave_read_first_lane(blockIdx.z); const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk); // options - const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A); - const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A); + const auto batch_stride_A = amd_wave_read_first_lane(kargs.batch_stride_A); + const auto batch_offset_A = amd_wave_read_first_lane(i_batch * batch_stride_A); const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + batch_offset_A + splitk_batch_offset.as_k_split_offset[0]; - const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B); - const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B); + const auto batch_stride_B = amd_wave_read_first_lane(kargs.batch_stride_B); + const auto batch_offset_B = amd_wave_read_first_lane(i_batch * batch_stride_B); const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + batch_offset_B + splitk_batch_offset.bs_k_split_offset[0]; - const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E); - const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E); + const auto batch_stride_E = amd_wave_read_first_lane(kargs.batch_stride_E); + const auto batch_offset_C = amd_wave_read_first_lane(i_batch * batch_stride_E); CDataType* c_ptr = static_cast(kargs.e_ptr) + batch_offset_C; // allocate LDS diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp index 3b050e03ed..b4ddc33e8d 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp @@ -132,6 +132,10 @@ struct GemmKernelMultiABD static constexpr index_t NumBTensor = BsDataType::size(); static constexpr index_t NumDTensor = DsDataType::size(); + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + using DDataType = remove_cvref_t>; + CK_TILE_HOST static auto GetName() -> const std::string { return UniversalGemmKernel::GetName(); @@ -181,6 +185,14 @@ struct GemmKernelMultiABD { return false; } + // Currently MultiABD kernel doesn't support F8 data type + if(ck_tile::get_device_name() == "gfx950" && + (std::is_same::value || + std::is_same::value || + std::is_same::value)) + { + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); } diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index a891d4df55..673f5abc34 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -73,8 +73,8 @@ struct GemmTile2DPartitioner CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple { - const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx); - const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy); + const index_t iM = amd_wave_read_first_lane(blockIdx); + const index_t iN = amd_wave_read_first_lane(blockIdy); return make_tuple(iM, iN); } }; @@ -143,8 +143,8 @@ struct GemmTile1DPartitioner { const index_t NBlocks = integer_divide_ceil(N_, NPerBlock); - const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks); - const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks); + const index_t iM = amd_wave_read_first_lane(blockIdx / NBlocks); + const index_t iN = amd_wave_read_first_lane(blockIdx - iM * NBlocks); return make_tuple(iM, iN); } diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index df1d6c9e4f..217637d605 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -23,10 +23,13 @@ namespace ck_tile { /// arguments object. It contain all necessary information required to build proper kernel /// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by /// stating all required information like M,N,K sizes and respective strides. + +template struct GroupedGemmHostArgs { CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_, const void* b_ptr_, + const std::array& ds_ptr_, void* e_ptr_, index_t k_batch_, index_t M_, @@ -34,15 +37,18 @@ struct GroupedGemmHostArgs index_t K_, index_t stride_A_, index_t stride_B_, + const std::array& stride_Ds_, index_t stride_E_) : a_ptr(a_ptr_), b_ptr(b_ptr_), + ds_ptr(ds_ptr_), e_ptr(e_ptr_), M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), + stride_Ds(stride_Ds_), stride_E(stride_E_), k_batch(k_batch_) { @@ -50,6 +56,7 @@ struct GroupedGemmHostArgs const void* a_ptr; const void* b_ptr; + const std::array ds_ptr; union { void* e_ptr; @@ -61,7 +68,7 @@ struct GroupedGemmHostArgs index_t K; index_t stride_A; index_t stride_B; - + const std::array stride_Ds; union { index_t stride_E; @@ -71,20 +78,23 @@ struct GroupedGemmHostArgs index_t k_batch; }; +template struct GemmTransKernelArg { - UniversalGemmKernelArgs<> group_karg; + UniversalGemmKernelArgs<1, 1, NumDTensor> group_karg; ck_tile::index_t block_start; ck_tile::index_t block_end; GemmTransKernelArg() = delete; - GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end) - : group_karg{karg}, block_start{bl_start}, block_end{bl_end} + GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg, + index_t bl_start, + index_t bl_end) + : group_karg{std::move(karg)}, block_start{bl_start}, block_end{bl_end} { } - GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg) - : group_karg{karg}, block_start{0}, block_end{0} + GemmTransKernelArg(UniversalGemmKernelArgs<1, 1, NumDTensor>&& karg) + : group_karg{std::move(karg)}, block_start{0}, block_end{0} { } }; @@ -106,9 +116,12 @@ struct GroupedGemmKernel using CLayout = remove_cvref_t; /// @brief Specify the data type configurations for A, B, C/E - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + static constexpr index_t NumDTensor_ = DsDataType::size(); /// @brief ALayout and ADataType are expected to be scalars, not a tuple. static_assert( @@ -140,19 +153,21 @@ struct GroupedGemmKernel concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), concat('x', P_::kPadM, P_::kPadN, P_::kPadK), - (UsePersistentKernel ? "Persistent" : "NonPersistent")); + (UsePersistentKernel ? "Persistent" : "NonPersistent"), + (NumDTensor_ == 2 ? "MultiD" : "NoMultiD"), + (GemmPipeline::DoubleSmemBuffer ? "DoubleSmemBuffer" : "SingleSmemBuffer")); // clang-format on } CK_TILE_HOST static auto - GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t + GetWorkSpaceSize(const std::vector>& gemm_descs) -> std::size_t { - return gemm_descs.size() * sizeof(GemmTransKernelArg); + return gemm_descs.size() * sizeof(GemmTransKernelArg); } CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t { - return group_count * sizeof(GemmTransKernelArg); + return group_count * sizeof(GemmTransKernelArg); } CK_TILE_HOST static auto BlockSize() -> dim3 @@ -184,7 +199,8 @@ struct GroupedGemmKernel return dim3(grid_size, 1, 1); } - CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) + CK_TILE_HOST static auto + GridSize(const std::vector>& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -196,9 +212,10 @@ struct GroupedGemmKernel } CK_TILE_HOST static auto - MakeKargs(const std::vector& gemm_descs) -> std::vector + MakeKargs(const std::vector>& gemm_descs) + -> std::vector> { - std::vector gemm_kernel_args_; + std::vector> gemm_kernel_args_; index_t group_count = ck_tile::type_convert(gemm_descs.size()); index_t grid_size = 0; gemm_kernel_args_.reserve(group_count); @@ -217,6 +234,7 @@ struct GroupedGemmKernel const index_t stride_a = gemm_descs[i].stride_A; const index_t stride_b = gemm_descs[i].stride_B; const index_t stride_e = gemm_descs[i].stride_E; + auto stride_ds = gemm_descs[i].stride_Ds; const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch; @@ -225,19 +243,19 @@ struct GroupedGemmKernel grid_size += grid_size_grp; - auto karg = - UniversalGemmKernelArgs<>{{type_convert(gemm_descs[i].a_ptr)}, - {type_convert(gemm_descs[i].b_ptr)}, - {/*ds_ptr*/}, - type_convert(gemm_descs[i].e_ptr), - M, - N, - K, - {stride_a}, - {stride_b}, - {/*stride_ds*/}, - stride_e, - gemm_descs[i].k_batch}; + auto karg = UniversalGemmKernelArgs<1, 1, NumDTensor_>{ + {type_convert(gemm_descs[i].a_ptr)}, + {type_convert(gemm_descs[i].b_ptr)}, + {gemm_descs[i].ds_ptr}, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + {stride_a}, + {stride_b}, + stride_ds, + stride_e, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -245,7 +263,8 @@ struct GroupedGemmKernel return gemm_kernel_args_; } - CK_TILE_HOST static bool IsSupportedArgument(const std::vector& kargs) + CK_TILE_HOST static bool + IsSupportedArgument(const std::vector>& kargs) { for(const auto& karg : kargs) { @@ -262,7 +281,7 @@ struct GroupedGemmKernel return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } - CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs, + CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs, const tuple& block_idx_2d, const index_t block_idx_z) const { @@ -272,8 +291,8 @@ struct GroupedGemmKernel const auto [iM, iN] = block_idx_2d; - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z); @@ -292,8 +311,16 @@ struct GroupedGemmKernel { __shared__ char smem_ptr_1[GetSmemSize()]; - RunGemmWithPipelineSelection2LDS( - a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n); + RunGemmWithPipelineSelection2LDS(a_ptr, + b_ptr, + c_ptr, + kargs.ds_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); } else // SingleSmemBuffer { @@ -306,7 +333,7 @@ struct GroupedGemmKernel { Base::RunGemm({a_ptr}, {b_ptr}, - {/*ds_ptr*/}, + kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, @@ -340,7 +367,7 @@ struct GroupedGemmKernel const BDataType* b_ptr, CDataType* c_ptr, void* smem_ptr_0, - const UniversalGemmKernelArgs<>& kargs, + const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -358,8 +385,8 @@ struct GroupedGemmKernel const auto& d_block_window = gemm_tile_windows.at(Base::I2); // Get hot-loop and tail configuration - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const index_t num_loop = + amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); @@ -396,9 +423,10 @@ struct GroupedGemmKernel RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, + const std::array& ds_ptr, void* __restrict__ smem_ptr_0, void* __restrict__ smem_ptr_1, - const UniversalGemmKernelArgs<>& kargs, + const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -406,7 +434,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); + {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = @@ -416,8 +444,8 @@ struct GroupedGemmKernel const auto& d_block_window = gemm_tile_windows.at(Base::I2); // Get hot-loop and tail configuration - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const index_t num_loop = + amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); // Run GEMM pipeline with compile-time branching @@ -453,7 +481,7 @@ struct GroupedGemmKernel c_block_window, c_block_tile, d_block_window, smem_ptr_0); } - CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr, + CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr, index_t block_id, index_t group_count) const { @@ -485,7 +513,7 @@ struct GroupedGemmKernel index_t group_count) const { const index_t block_id = ck_tile::get_block_1d_id(); - const auto gemm_desc_ptr = reinterpret_cast( + const auto gemm_desc_ptr = reinterpret_cast*>( cast_pointer_to_generic_address_space(gemm_descs_const)); const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count); @@ -508,7 +536,7 @@ struct GroupedGemmKernel const index_t group_count) const { const index_t grid_size = ck_tile::get_grid_size(); - const auto gemm_desc_ptr = reinterpret_cast( + const auto gemm_desc_ptr = reinterpret_cast*>( cast_pointer_to_generic_address_space(gemm_descs_const)); index_t block_id = ck_tile::get_block_1d_id(); // initial block_id index_t cum_grid_size = 0; diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 5df1f092d7..ad85b5392d 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -271,8 +271,8 @@ struct StreamKKernel uint32_t block_idx = ck_tile::get_block_1d_id(); bool is_padding_block = - __builtin_amdgcn_readfirstlane(block_idx >= kargs.tile_partitioner.sk_num_blocks && - block_idx < kargs.tile_partitioner.dp_start_block_idx); + amd_wave_read_first_lane(block_idx >= kargs.tile_partitioner.sk_num_blocks && + block_idx < kargs.tile_partitioner.dp_start_block_idx); // Padding blocks make it such that the DP blocks are aligned with the number of CUs; they // should not partake in the GEMM @@ -289,7 +289,7 @@ struct StreamKKernel { // Determine the number of macro tiles in A and B this WG is resposible for in the // current C macro tile. - uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + uint32_t current_iter_length = amd_wave_read_first_lane( kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end)); // Determine the 1D tile_idx and the iter_offset for this WG. diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 8f44108cc4..51ad4e3dd1 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -326,19 +326,19 @@ struct UniversalGemmKernel __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); - const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); - const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); + const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); + const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1); static_for<0, NumATensor, 1>{}([&](auto index) { using AiLayout = remove_cvref_t>; if constexpr(std::is_same_v) { - as_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead); + as_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead); } else if constexpr(std::is_same_v) { as_k_split_offset[index] = - __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_As[index]); + amd_wave_read_first_lane(k_id * KRead * kargs.stride_As[index]); } }); @@ -347,21 +347,21 @@ struct UniversalGemmKernel if constexpr(std::is_same_v) { bs_k_split_offset[index] = - __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_Bs[index]); + amd_wave_read_first_lane(k_id * KRead * kargs.stride_Bs[index]); } else if constexpr(std::is_same_v) { - bs_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead); + bs_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead); } }); if(k_id < static_cast(kargs.k_batch - 1)) { - splitted_k = __builtin_amdgcn_readfirstlane(KRead); + splitted_k = amd_wave_read_first_lane(KRead); } else { - splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1)); } } @@ -970,8 +970,8 @@ struct UniversalGemmKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const index_t num_loop = + amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. const auto& as_block_window = gemm_tile_windows.at(I0); @@ -1026,8 +1026,8 @@ struct UniversalGemmKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const index_t num_loop = + amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. const auto& as_block_window = gemm_tile_windows.at(I0); @@ -1052,10 +1052,10 @@ struct UniversalGemmKernel template > CK_TILE_DEVICE void operator()(KernelArgs kargs) const { - const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto blockId = amd_wave_read_first_lane(blockIdx.x); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); @@ -1126,22 +1126,22 @@ struct UniversalGemmKernel template , typename = void> CK_TILE_DEVICE void operator()(KernelArgs kargs) const { - const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); + const auto grid_size = amd_wave_read_first_lane(get_grid_size()); const auto num_tiles = - __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); - const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); - auto block_id = __builtin_amdgcn_readfirstlane(get_block_id()); + amd_wave_read_first_lane(TilePartitioner::GridSize(kargs.M, kargs.N)); + const auto num_work = amd_wave_read_first_lane(num_tiles * kargs.k_batch); + auto block_id = amd_wave_read_first_lane(get_block_id()); while(block_id < num_work) { // Get the tile index for this block - const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles); + const auto tile_idx = amd_wave_read_first_lane(block_id % num_tiles); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); // Get the SplitK offset for this block - const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); + const auto k_batch = amd_wave_read_first_lane(block_id / num_tiles); const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); std::array as_ptr; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 7159eda683..2b0b2e8488 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -530,7 +530,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); __builtin_amdgcn_sched_barrier(0); @@ -542,7 +543,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -553,7 +554,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -577,7 +578,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -596,7 +598,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(is_a_col_major) + if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); @@ -607,7 +609,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - if constexpr(is_b_row_major) + if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); @@ -619,7 +621,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index b362f751c6..d0466bc8b1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -487,7 +487,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 if(HasHotLoop) { // minus 2 because we have ping-pong double buffer. - index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2); + index_t iCounter = amd_wave_read_first_lane(num_loop - 2); do { // ping diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index 474d1a5a21..7263ddd5a1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -178,7 +178,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 index_t warp_id = get_warp_id(); index_t operation_id = - __builtin_amdgcn_readfirstlane(get_warp_id()); // 0 - Memory read, 1 - block-gemm + amd_wave_read_first_lane(get_warp_id()); // 0 - Memory read, 1 - block-gemm auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock); auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock); @@ -336,7 +336,7 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 MemoryOpsStep(warp_id); } - index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop); + index_t num_compute_steps = amd_wave_read_first_lane(num_loop); while(num_compute_steps > 1) { block_sync_lds(); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index c73fa29245..75790afecd 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -100,7 +100,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; @@ -118,7 +118,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 82bf75a9e3..0c9c816672 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -18,73 +18,64 @@ namespace ck_tile { namespace detail { // Helper templates for safe type extraction -template +template struct get_aq_layout_or { using type = Default; }; template - requires requires { typename T::AQLayout; } -struct get_aq_layout_or +struct get_aq_layout_or> { using type = typename T::AQLayout; }; -template +template struct get_bq_layout_or { using type = Default; }; template - requires requires { typename T::BQLayout; } -struct get_bq_layout_or +struct get_bq_layout_or> { using type = typename T::BQLayout; }; -template +template struct get_aq_data_type_or { using type = Default; }; template - requires requires { typename T::AQDataType; } -struct get_aq_data_type_or +struct get_aq_data_type_or> { using type = typename T::AQDataType; }; -template +template struct get_bq_data_type_or { using type = Default; }; template - requires requires { typename T::BQDataType; } -struct get_bq_data_type_or +struct get_bq_data_type_or> { using type = typename T::BQDataType; }; -template -concept HasStaticPreshuffleQuant = requires { - { T::PreshuffleQuant } -> std::convertible_to; -}; - -template +template struct is_quantpreshuffle_enabled { static constexpr bool value = false; }; -template -struct is_quantpreshuffle_enabled +template +struct is_quantpreshuffle_enabled { - static constexpr auto value = T::PreshuffleQuant; + static constexpr bool value = T::PreshuffleQuant; }; } // namespace detail @@ -270,34 +261,34 @@ struct QuantGemmKernel const std::size_t k_id = blockIdx.z) { constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(I2); - const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); - const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); + const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); + const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1); if constexpr(std::is_same_v) { - a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + a_k_split_offset = amd_wave_read_first_lane(k_id * KRead); } else if constexpr(std::is_same_v) { - a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); + a_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_A); } if constexpr(std::is_same_v) { - b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); + b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B); } else if constexpr(std::is_same_v) { - b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + b_k_split_offset = amd_wave_read_first_lane(k_id * KRead); } if(k_id < static_cast(kargs.k_batch - 1)) { - splitted_k = __builtin_amdgcn_readfirstlane(KRead); + splitted_k = amd_wave_read_first_lane(KRead); } else { - splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1)); } } @@ -918,8 +909,8 @@ struct QuantGemmKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const index_t num_loop = + amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); @@ -981,10 +972,10 @@ struct QuantGemmKernel CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const { - const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto blockId = amd_wave_read_first_lane(blockIdx.x); const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); // options diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 07c45117e2..39c8e406b7 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -305,8 +305,8 @@ struct QuantGroupedGemmKernel { const auto [iM, iN] = block_idx_2d; - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index e97eeffb9b..3b5bff03d4 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -16,7 +16,7 @@ enum struct QuantType : std::uint16_t TensorQuant = 3 }; -std::string quant_type_to_string(QuantType quant_type) +inline std::string quant_type_to_string(QuantType quant_type) { switch(quant_type) { diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index 15e697afdf..e68a510a0c 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -840,7 +840,7 @@ struct GroupedConvolutionBackwardDataKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum( + const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum( gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1))); // Run GEMM cooperatively by whole workgroup. @@ -891,7 +891,7 @@ struct GroupedConvolutionBackwardDataKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = __builtin_amdgcn_readfirstlane( + const index_t num_loop = amd_wave_read_first_lane( TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1))); // Run GEMM cooperatively by whole workgroup. @@ -936,7 +936,7 @@ struct GroupedConvolutionBackwardDataKernel CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const { - const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto blockIdX = amd_wave_read_first_lane(blockIdx.x); const index_t group_id = FindGroupId(kargs, blockIdX); const auto [iM, iN] = OffsettedTile1DPartitioner::GetOffsetedTileIndex( @@ -944,13 +944,13 @@ struct GroupedConvolutionBackwardDataKernel kargs.c_grid_descs_m_n[group_id].get_length(I0), kargs.c_grid_descs_m_n[group_id].get_length(I1)); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); - const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); - const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); - const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); + const auto blockIdY = amd_wave_read_first_lane(blockIdx.y); + const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY); // options // conv_bwd_data = Out * Weight = In diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 7bb3fedaf6..b85660aea3 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -423,22 +423,20 @@ struct GroupedConvolutionBackwardWeightKernel __device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); - const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); - const index_t KRead = - __builtin_amdgcn_readfirstlane((kargs.GemmK + K_t - 1) / K_t * K1); + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); + const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1); - a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); - b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + a_k_split_offset = amd_wave_read_first_lane(k_id * KRead); + b_k_split_offset = amd_wave_read_first_lane(k_id * KRead); if(k_id < static_cast(kargs.k_batch - 1)) { - splitted_k = __builtin_amdgcn_readfirstlane(KRead); + splitted_k = amd_wave_read_first_lane(KRead); } else { - splitted_k = - __builtin_amdgcn_readfirstlane(kargs.GemmK - KRead * (kargs.k_batch - 1)); + splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1)); } } @@ -805,22 +803,22 @@ struct GroupedConvolutionBackwardWeightKernel CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const { - const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto blockIdX = amd_wave_read_first_lane(blockIdx.x); const auto [iM, iN] = TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t num_loop = __builtin_amdgcn_readfirstlane( + const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z); + const index_t num_loop = amd_wave_read_first_lane( ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock)); const index_t i_k = - __builtin_amdgcn_readfirstlane(blockIdZ * num_loop * TilePartitioner::KPerBlock); + amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock); - const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); - const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); - const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); - const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); + const auto blockIdY = amd_wave_read_first_lane(blockIdx.y); + const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY); // options // conv_bwd_weight = Out * In = Weight diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index d1eacd60cd..0363782d33 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -752,8 +752,7 @@ struct GroupedConvolutionForwardKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = - __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK)); + const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK)); // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); @@ -802,8 +801,7 @@ struct GroupedConvolutionForwardKernel const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const index_t num_loop = - __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK)); + const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK)); // Run GEMM cooperatively by whole workgroup. const auto& a_block_window = gemm_tile_windows.at(I0); @@ -822,22 +820,22 @@ struct GroupedConvolutionForwardKernel CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const { - const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto blockIdX = amd_wave_read_first_lane(blockIdx.x); const auto [iM, iN] = TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); - const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); - const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); - const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); + const auto blockIdY = amd_wave_read_first_lane(blockIdx.y); + const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY); // Split-N handling: Get which split this workgroup handles - const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z); + const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z); // Calculate batch offset for this split - const index_t batch_offset = __builtin_amdgcn_readfirstlane(blockIdZ * kargs.n_per_split); + const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split); // Calculate memory offsets for this split const long_index_t input_batch_offset = static_cast(batch_offset) * diff --git a/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp index eb54807d88..bc20057e7a 100644 --- a/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp +++ b/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp @@ -175,9 +175,9 @@ struct ImageToColumn { const auto [M, K] = CalculateMKDims(kargs); - const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); - const index_t iK = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock); - const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t iM = amd_wave_read_first_lane(blockIdx.x * kMPerBlock); + const index_t iK = amd_wave_read_first_lane(blockIdx.y * kKPerBlock); + const index_t iBatch = amd_wave_read_first_lane(blockIdx.z); const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0]; const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0]; diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index 0de1ada87c..422950b143 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -99,7 +99,7 @@ struct Layernorm2dFwdPipelineTwoPass // Problem::BlockShape static constexpr index_t Block_N = Problem::BlockShape::Block_N; index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + amd_wave_read_first_lane(integer_divide_ceil(row_size, Block_N)); // total number of count assume current iter have no pad(only last iter has pad) constexpr index_t count_per_iter = @@ -119,7 +119,7 @@ struct Layernorm2dFwdPipelineTwoPass auto mean = block_norm_reduce.template MakeMeanVarBlockTile(); auto var = block_norm_reduce.template MakeMeanVarBlockTile(); - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { auto x = load_tile(x_window); auto x_resi = load_tile(x_residual_window); @@ -197,7 +197,7 @@ struct Layernorm2dFwdPipelineTwoPass move_tile_window(y_window, {0, stride_to_right_most_window}); // layernorm computation - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { auto acc = make_static_distributed_tensor( decltype(load_tile(x_window))::get_tile_distribution()); diff --git a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp index 92a71a42c8..83a22aaded 100644 --- a/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp +++ b/include/ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp @@ -156,7 +156,7 @@ struct Reduce const auto merged_reduce_len = transformed_x_tensor.get_tensor_descriptor().get_lengths().at(number<1>{}); index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane(integer_divide_ceil(merged_reduce_len, S::Block_N)); + amd_wave_read_first_lane(integer_divide_ceil(merged_reduce_len, S::Block_N)); auto block_reduce2d = Policy::template GetBlockReduce2d(); auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); @@ -167,7 +167,7 @@ struct Reduce auto y_compute = block_reduce2d.template MakeYBlockTile(); set_tile(y_compute, reduce_func.template GetIdentityValue()); - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { const auto x = load_tile(x_window); block_reduce2d(x, y_compute, reduce_func); diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp index d01f37879a..ca3cdc37c4 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -82,7 +82,7 @@ struct Rmsnorm2dFwdPipelineTwoPass // Problem::BlockShape static constexpr index_t Block_N = Problem::BlockShape::Block_N; index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + amd_wave_read_first_lane(integer_divide_ceil(row_size, Block_N)); auto reduce_square_sum_func = ReduceOp::SquareAdd{}; auto reduce_sum_func = ReduceOp::Add{}; @@ -95,7 +95,7 @@ struct Rmsnorm2dFwdPipelineTwoPass auto square_sum = block_reduce2d.template MakeYBlockTile(); set_tile(square_sum, reduce_square_sum_func.GetIdentityValue()); - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { auto x = load_tile(x_window); auto x_resi = load_tile(x_residual_window); @@ -151,7 +151,7 @@ struct Rmsnorm2dFwdPipelineTwoPass move_tile_window(y_window, {0, stride_to_right_most_window}); // rmsnorm computation - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { auto acc = make_static_distributed_tensor( decltype(load_tile(x_window))::get_tile_distribution()); diff --git a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp index 2553b19fd8..f6c7c0753a 100644 --- a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp @@ -138,7 +138,7 @@ struct MoeSmoothquant const index_t i_topk = blockIdx.x; const index_t i_token = blockIdx.y * Block_M; const index_t i_token_in_thrd = - __builtin_amdgcn_readfirstlane(threadIdx.x / Problem::BlockShape::ThreadPerBlock_N); + amd_wave_read_first_lane(threadIdx.x / Problem::BlockShape::ThreadPerBlock_N); const index_t i_expert = reinterpret_cast( kargs.p_topk_ids)[(i_token + i_token_in_thrd) * kargs.topk + i_topk]; diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp index ba9c6374f1..8b0a7274ed 100644 --- a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp @@ -57,7 +57,7 @@ struct SmoothquantPipelineTwoPass static constexpr index_t Block_N = Problem::BlockShape::Block_N; index_t num_n_tile_iteration = - __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + amd_wave_read_first_lane(integer_divide_ceil(row_size, Block_N)); auto reduce_absmax_func = ReduceOp::AbsMax{}; auto reduce_absmax3_func = [](auto acc_, auto v_0_, auto v_1_) { @@ -77,7 +77,7 @@ struct SmoothquantPipelineTwoPass auto absmax = block_reduce2d.template MakeYBlockTile(); set_tile(absmax, reduce_absmax_func.GetIdentityValue()); - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { const auto x = load_tile(x_window); const auto smscale = load_tile(smscale_window); @@ -121,7 +121,7 @@ struct SmoothquantPipelineTwoPass move_tile_window(qy_window, {0, stride_to_right_most_window}); // recompute y and quantize y to qy - for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + for(int iN = amd_wave_read_first_lane(0); iN < num_n_tile_iteration; ++iN) { const auto x = load_tile(x_window); const auto smscale = load_tile(smscale_window); diff --git a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp index 277049f6b0..e8727ea065 100644 --- a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp +++ b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp @@ -96,9 +96,9 @@ struct TopkSoftmaxKernel if(block_row_id > kargs.num_rows) return; - index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input); - index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output); - index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id); + index_t block_os_inp = amd_wave_read_first_lane(block_row_id * kargs.stride_input); + index_t block_os_out = amd_wave_read_first_lane(block_row_id * kargs.stride_output); + index_t num_rows_rem = amd_wave_read_first_lane(kargs.num_rows - block_row_id); const auto input_window = [&]() { const InputType* p_input = diff --git a/include/ck_tile/utility/json_dump.hpp b/include/ck_tile/utility/json_dump.hpp index d7c96d77b8..26af906ed0 100644 --- a/include/ck_tile/utility/json_dump.hpp +++ b/include/ck_tile/utility/json_dump.hpp @@ -1,10 +1,10 @@ +#ifdef CK_ENABLE_JSON_DUMP #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wzero-as-null-pointer-constant" #include "rapidjson/writer.h" #include "rapidjson/stringbuffer.h" #include "rapidjson/document.h" #include "rapidjson/rapidjson.h" -// #include #pragma GCC diagnostic pop #define START_JSON_DUMP_FILE(file_name) \ @@ -76,6 +76,18 @@ static void add_perf_to_json(rapidjson::Writer& writer, writer.EndArray(); } +#else +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wunused-local-typedef" +#define START_JSON_DUMP_FILE(file_name) +#define END_JSON_DUMP_FILE() \ + std::cout << "JSON dump disabled, To enable, set CK_ENABLE_JSON_DUMP cmake option" << std::endl; + +#define ADD_KEY_VALUE(key, value) +#define ADD_PERF_TO_JSON(_time, tflops, gbytes) +#endif + // Helper traits to check for static member existence template struct has_warp_tile_members : std::false_type @@ -698,3 +710,7 @@ void dump_fmha_bwd_json_results(const std::string& json_filename, ADD_PERF_TO_JSON(ave_time, tflops, gb_per_sec) END_JSON_DUMP_FILE(); } + +#ifndef CK_ENABLE_JSON_DUMP +#pragma GCC diagnostic pop +#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index 59dfd76ede..d9c6cc5027 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -172,26 +172,26 @@ struct ReferenceMoeGemm : public device::BaseOperator if constexpr(ActivationType == 1) { - v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t, 0); if constexpr(is_same_v) { v_c_up *= 16; v_c *= 16; } tensor_operation::element_wise::Silu{}(v_c, v_c); - v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t, 0); arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; } else if constexpr(ActivationType == 0) { - v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t, 0); if constexpr(is_same_v) { v_c_up *= 16; v_c *= 16; } tensor_operation::element_wise::Gelu{}(v_c, v_c); - v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t, 0); arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; } } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index 58e4adfdfa..33239c94ec 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -144,8 +144,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator ck::type_convert(v_a) * ck::type_convert(v_b); } CDataType v_c{0}; - D0DataType v_d0 = arg.d0_(t, topk_id); // a - D0DataType v_d1 = arg.d1_(e, n); // b + D0DataType v_d0 = arg.d0_.mDesc.GetNumOfDimension() == 3 + ? arg.d0_(t, topk_id, 0) + : arg.d0_(t, topk_id); // a + + D0DataType v_d1 = arg.d1_(e, n); // b if constexpr(MulRoutedWeight) { arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 9aeca39718..ec1b379ead 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -48,6 +48,9 @@ using BF16_Tuple = ck::Tuple; using F32_F32_Tuple = ck::Tuple; +// Generic layouts +using Bypass = ck::tensor_layout::BypassLayoutVerification; + // GEMM layout using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp index 6e2950180d..3ebfdfa0d3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,11 +17,229 @@ namespace tensor_operation { namespace device { namespace instance { -using Multiply = ck::tensor_operation::element_wise::Multiply; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; #ifdef CK_ENABLE_INT8 + +#ifdef CK_USE_WMMA +// RRR +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// RCR +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// CRR +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple<>, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>>>& instances); + +// Multiply +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>>>& instances); + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>>>& instances); + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ck::Tuple, + Row, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>>& instances); + +#endif + +#ifdef CK_USE_XDL // RRR void add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( std::vector, @@ -198,7 +416,7 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( std::vector, ck::Tuple, - ck::Tuple, + ck::Tuple, Row, ck::Tuple, ck::Tuple, @@ -233,10 +451,88 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( PassThrough, PassThrough, Multiply>>>& instances); - +#endif #endif // GEMM + Add + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_bias_gelu_v1_instances( + op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances( + op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -300,6 +597,27 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances(op_ptrs); } } +#endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA #endif return op_ptrs; @@ -307,6 +625,81 @@ struct DeviceOperationInstanceFactory< }; // GEMM + Add +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_bias_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances(op_ptrs); + } + } +#endif + +#endif + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -372,11 +766,107 @@ struct DeviceOperationInstanceFactory< } #endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA +#endif + return op_ptrs; } }; // GEMM + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_gelu_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances(op_ptrs); + } + } +#endif +#endif + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -442,11 +933,106 @@ struct DeviceOperationInstanceFactory< } #endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif +#endif return op_ptrs; } }; // GEMM +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_km_kn_mn_v1_instances(op_ptrs); + } + + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances(op_ptrs); + } + } +#endif +#endif + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -511,13 +1098,95 @@ struct DeviceOperationInstanceFactory< } } #endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif +#endif return op_ptrs; } }; // Multiply // GEMM + Add + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -568,6 +1238,27 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif #endif return op_ptrs; @@ -575,6 +1266,67 @@ struct DeviceOperationInstanceFactory< }; // GEMM + Add +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + op_ptrs); + } + } +#endif +#endif + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -625,6 +1378,27 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif #endif return op_ptrs; @@ -632,6 +1406,68 @@ struct DeviceOperationInstanceFactory< }; // GEMM + Gelu +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -682,6 +1519,27 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif #endif return op_ptrs; @@ -689,6 +1547,67 @@ struct DeviceOperationInstanceFactory< }; // GEMM +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleABDSplitK> +{ + using DeviceOp = DeviceGemmMultipleABDSplitK; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif + +#ifdef CK_USE_WMMA + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + if constexpr(is_same_v> && + is_same_v> && + is_same_v> && is_same_v) + { + add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances(op_ptrs); + } + } +#endif +#endif + + return op_ptrs; + } +}; + template > op_ptrs; #ifdef CK_ENABLE_INT8 +#ifdef CK_USE_XDL if constexpr(is_same_v> && is_same_v> && is_same_v> && is_same_v) @@ -740,6 +1660,28 @@ struct DeviceOperationInstanceFactory< } #endif +#ifdef CK_USE_WMMA + using Wrapper = DeviceGemmMultipleABDSplitKWrapper; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif +#endif + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp index 1c3bfef8ce..416e64b534 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp @@ -16,6 +16,7 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -139,6 +140,40 @@ using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32> + // clang-format on + >; + template using S = ck::Sequence; @@ -205,6 +206,27 @@ using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple< // clang-format on >; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_tf32_comp_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, TF32, TF32> + // clang-format on + >; + // double rate mfma instances on gfx950 template ; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_tf32_generic_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| AComputeType| BComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| DATATYPE | DATATYPE | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32> + // clang-format on + >; + template using S = ck::Sequence; @@ -99,6 +100,27 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple< // clang-format on >; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32> + // clang-format on + >; + template using S = ck::Sequence; @@ -64,7 +65,7 @@ using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple< //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // Latency friendly + // Latency friendly DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, @@ -163,6 +164,41 @@ using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple< // clang-format on >; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_tf32_mem_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, TF32, TF32>, + // Memory friendly + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, TF32, TF32> + // clang-format on + >; + template using S = ck::Sequence; @@ -142,6 +143,27 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< // clang-format on >; +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Instances with NumGroupsPerBatch > 1 + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, TF32, TF32, LoopScheduler::Default, 32> + // clang-format on + >; + template using S = ck::Sequence; @@ -139,6 +140,40 @@ using device_grouped_conv_fwd_xdl_scale_f32_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_scale_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32> + // clang-format on + >; + template using S = ck::Sequence; @@ -89,7 +90,7 @@ using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, - // instances for small conv.K and conv.C + // instances for small conv.K and conv.C DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 4>, @@ -97,6 +98,27 @@ using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_scaleadd_ab_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,ELayout, ck::Tuple, ck::Tuple, F32, F32, ck::Tuple<>, F32, ScaleAdd, ScaleAdd, PassThrough, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32> + // clang-format on + >; + template && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( - op_ptrs); - } - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(op_ptrs); + if constexpr(is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + op_ptrs); + } } #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index 11e827878c..e41e1b833b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -127,24 +127,44 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( - op_ptrs); + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + op_ptrs); + } } + #endif } // layout NDHWGC/GKZYXC/NDHWGK @@ -197,32 +217,44 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( - op_ptrs); + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + op_ptrs); + } } - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - op_ptrs); - } #endif } #endif // CK_USE_XDL diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc index 045d1623cf..4678ab6c66 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc @@ -480,6 +480,22 @@ void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instance PassThrough, AddClamp>>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp index c8375da6e1..08bea2ce45 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp @@ -68,6 +68,22 @@ void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instanc PassThrough, PassThrough, Bilinear>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple, + F32, + PassThrough, + PassThrough, + Bilinear, + TF32, + TF32>>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -137,8 +153,16 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp index c4fbbf1d90..f2c62564c3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp @@ -125,23 +125,44 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( - op_ptrs); + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( + op_ptrs); + } } + #endif } // layout NDHWGC/GKZYXC/NDHWGK @@ -193,30 +214,42 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v) { - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( - op_ptrs); - } - - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( - op_ptrs); + static_assert(is_same_v, + "Error: AComputeType and BComputeType should be the same"); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( + op_ptrs); + } } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc index b0061b966d..c0c3007651 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc @@ -480,6 +480,22 @@ void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances( PassThrough, Clamp>>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances( std::vector>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc index b830bdce71..91221c2c0c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc @@ -111,6 +111,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -281,6 +296,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc index 00351ceefd..ac7a773aff 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc @@ -55,6 +55,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -169,6 +184,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instan PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc index bd44116057..68cbc56b41 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc @@ -55,6 +55,21 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -169,6 +184,21 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instan PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp index c4bc1da57e..d11c80babf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp @@ -68,6 +68,22 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( PassThrough, PassThrough, Scale>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + Scale, + TF32, + TF32>>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -137,7 +153,16 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + if constexpr(is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } + else + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } } #endif #ifdef CK_ENABLE_FP16 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index af6041bbc5..a59fcd9d6e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -211,6 +211,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc index 5f35ab5a4b..e67d71f8ab 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc @@ -55,6 +55,22 @@ void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instan PassThrough, PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances); + #endif #ifdef CK_ENABLE_INT8 @@ -120,6 +136,22 @@ void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_ins PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances); #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc index 9f54c4b633..eedbd1abd0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc @@ -84,6 +84,22 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_insta PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_f32_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances); #endif // grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt index 5af7322b1a..5ce585ad81 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/CMakeLists.txt @@ -1,16 +1,26 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_MULTI_ABD_INSTANCES) list(APPEND GEMM_MULTI_ABD_INSTANCES - device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp - ) + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp + device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + + device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp +) add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp new file mode 100644 index 0000000000..8d4c45ae82 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = BF16; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Row; +using B1Layout = B0Layout; +using D0Layout = Row; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; +using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; +using Add = ck::tensor_operation::element_wise::Add; + +using AElementOp = PassThrough; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 256, 32, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 32, 16, 16, 256, 8, 8, 16, 16, 1, 1, S<32, 1, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 2>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 64, 16, 32, 256, 8, 8, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..eef450533b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple<>, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple<>, + EDataType, + AElementOp, + Multiply, + PassThrough>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + PassThrough, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + PassThrough, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp new file mode 100644 index 0000000000..0c2a34fbf8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using I8 = int8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = BF16; +using AsDataType = ck::Tuple; +using B0DataType = I8; +using B1DataType = BF16; +using BsDataType = ck::Tuple; +using AccDataType = F32; +using CShuffleDataType = BF16; +using D0DataType = BF16; +using EDataType = BF16; + +using A0Layout = Row; +using AsLayout = ck::Tuple; +using B0Layout = Col; +using B1Layout = B0Layout; +using BsLayout = ck::Tuple; +using D0Layout = Row; +using ELayout = Row; + +using Multiply = ck::tensor_operation::element_wise::Multiply; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; +using Add = ck::tensor_operation::element_wise::Add; + +using AElementOp = PassThrough; +using BElementOp = Multiply; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +// using CDEElementOp = AddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +template +using device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //###################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemmPipeSched| BlkGemmPipelineVer| + //###################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| | | + //###################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | | | + //###################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemmMultipleABD_Wmma_CShuffleV3< AsLayout, BsLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..30ab4135d9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + Multiply, + Add>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + Add, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..56d30f9ad2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + Multiply, + AddFastGelu>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + AddFastGelu, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Multiply, + AddFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp new file mode 100644 index 0000000000..d4b9054a73 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp @@ -0,0 +1,111 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_gelu_v1_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + AddFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances, + ck::Tuple, + AddFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_bias_v1_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + Add>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances, + ck::Tuple, + Add, + GemmMNKPadding, + Interwave>{}); +} + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_v1_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances, + ck::Tuple<>, + PassThrough, + GemmMNKPadding, + Interwave>{}); +} + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_gelu_v1_instances( + std::vector, + ELayout, + AsDataType, + BsDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + FastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_nk_mn_comp_instances, + ck::Tuple<>, + FastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..cfeaad1a66 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple<>, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple<>, + EDataType, + AElementOp, + Multiply, + FastGelu>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); + + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple<>, + Multiply, + FastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..fe36c30e75 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + Multiply>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + Multiply, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..69b0e6ff0b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyAdd>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAdd, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..a779f27f62 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyAddFastGelu>>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAddFastGelu, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances(instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyAddFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp new file mode 100644 index 0000000000..dec51f72aa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_wmma_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_wmma_cshuffle_v3.hpp" + +#include "device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances( + std::vector, + ck::Tuple, + ELayout, + AsDataType, + ck::Tuple, + ck::Tuple, + EDataType, + AElementOp, + PassThrough, + MultiplyFastGelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); + add_device_operation_instances( + instances, + device_gemm_wmma_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances, + ck::Tuple, + ck::Tuple, + ck::Tuple, + PassThrough, + MultiplyFastGelu, + GemmMNKPadding, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 7f3621a2ba..5987b90685 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -9,6 +9,7 @@ set(GROUPED_CONV2D_FWD xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp @@ -28,12 +29,14 @@ set(GROUPED_CONV2D_FWD xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instance.cpp # merged groups # NHWGC, GKYXC, NHWGK xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instance.cpp # NGCHW, GKCYX, NGKHW xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -44,9 +47,11 @@ set(GROUPED_CONV2D_FWD xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp # NHWGC, GKYXC, NHWGK xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp # NGCHW, GKCYX, NGKHW xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_mem_intra_instance.cpp @@ -61,6 +66,7 @@ set(GROUPED_CONV2D_FWD xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp new file mode 100644 index 0000000000..352aa82d9f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp new file mode 100644 index 0000000000..8143553d54 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp new file mode 100644 index 0000000000..9a81ccbb82 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp new file mode 100644 index 0000000000..676e2d4a27 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp new file mode 100644 index 0000000000..5601638e77 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.cpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp new file mode 100644 index 0000000000..5f3f2a2247 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt index c06e4f5953..a801144bfd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -85,6 +85,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in + NUM_SHARDS 2 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor +) + # merged groups # NHWGC, GKYXC, NHWGK @@ -114,6 +124,15 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in + NUM_SHARDS 3 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups +) #mem # NHWGC, GKYXC, NHWGK @@ -143,6 +162,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + # NHWGC, GKYXC, NHWGK set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) @@ -171,6 +200,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + #comp # NHWGC, GKYXC, NHWGK @@ -200,7 +239,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in + NUM_SHARDS 4 + SRC_LIST GROUPED_CONV2D_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in new file mode 100644 index 0000000000..d12ae33a8e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instance.in @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in new file mode 100644 index 0000000000..6073ad94d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances& + instances) +{ + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in new file mode 100644 index 0000000000..f516770698 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instance.in @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwdDefault, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in new file mode 100644 index 0000000000..75aabfaa94 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instance.in @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in new file mode 100644 index 0000000000..3d147035db --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instance.in @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances = + std::vector< + std::unique_ptr, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances_shard( + device_grouped_conv2d_fwd_bias_bn_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances& + instances) +{ + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances< + 2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd3x3, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt index e63ac766b6..41274f8027 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/CMakeLists.txt @@ -21,10 +21,16 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp16_comp_part2_instance.cpp xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp new file mode 100644 index 0000000000..61b471cb1c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/comp/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..0bf7f8b7b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..b982a92b02 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp new file mode 100644 index 0000000000..d9835d7658 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Interwave, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp new file mode 100644 index 0000000000..43c04443c4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/mem/device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + AddClamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..77905b3f67 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd3x3, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt index 8faed08c05..f0404cd0f4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/CMakeLists.txt @@ -21,10 +21,16 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp16_comp_part2_instance.cpp xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_instance.cpp + xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp new file mode 100644 index 0000000000..9977482f8a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/comp/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_comp_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_comp_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..a4b16917bb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..f4933e62b8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/large_tensor/device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp new file mode 100644 index 0000000000..b1e53145e3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_inter_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Interwave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Interwave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp new file mode 100644 index 0000000000..74555cc227 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/mem/device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_fp32_tf32_mem_intra_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_intra_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Intrawave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..b004b4f3cf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_clamp/xdl/merged_groups/device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_clamp_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_tf32_instances( + std::vector, + NHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<2, + NHWGC, + GKYXC, + Tuple<>, + NHWGK, + ConvFwd3x3, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index d0ae0ad42e..5774db21c9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -20,10 +20,12 @@ set(GROUPED_CONV3D_FWD xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp @@ -31,13 +33,16 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp -xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp new file mode 100644 index 0000000000..63ff09234c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 0000000000..b6c8cd1bdb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/large_tensor/device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp new file mode 100644 index 0000000000..fe6141ac69 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp new file mode 100644 index 0000000000..633123e3c8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.cpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 0000000000..d4a05792d7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd3x3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index 6a776b4943..b6377ba2b4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -94,6 +94,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in + NUM_SHARDS 2 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor +) + # merged groups # NDHWGC, GKZYXC, NDHWGK @@ -123,6 +133,15 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in + NUM_SHARDS 3 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups +) #mem # NDHWGC, GKZYXC, NDHWGK @@ -154,6 +173,15 @@ generate_sharded_instantiations( ) # NDHWGC, GKZYXC, NDHWGK +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances @@ -180,6 +208,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + #comp # NDHWGC, GKZYXC, NDHWGK @@ -210,6 +248,15 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in + NUM_SHARDS 4 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in new file mode 100644 index 0000000000..352b8207b3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instance.in @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in new file mode 100644 index 0000000000..74308b1c9d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& + instances) +{ + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in new file mode 100644 index 0000000000..b87dce8411 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instance.in @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwdDefault, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in new file mode 100644 index 0000000000..c1df1e262e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instance.in @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances& + instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in new file mode 100644 index 0000000000..a857b7de4f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& + instances) +{ + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances< + 3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd3x3, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index bcc7020ca9..ef7cc22bc4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -19,10 +19,15 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp + xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp new file mode 100644 index 0000000000..4b60dd1b3e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..04d750d2b9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp new file mode 100644 index 0000000000..765719c7b5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Interwave, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp new file mode 100644 index 0000000000..0daf28adef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..2988b715e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd3x3, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt index 436c37fd58..6a4637d6e1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt @@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_bilinear_instance ${GROUPED_CONV3D_FWD_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 0000000000..869c812b50 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple, + F32, + PassThrough, + PassThrough, + Bilinear, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index 059d22f8d2..0c126b2084 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -19,10 +19,15 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_16x16_instance.cpp xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp + xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_instance.cpp + xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp new file mode 100644 index 0000000000..3a99d693f9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_comp_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_comp_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..5859576835 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/large_tensor/device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp new file mode 100644 index 0000000000..905da7e1d0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_inter_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Interwave, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Interwave, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp new file mode 100644 index 0000000000..008dd28921 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_mem_intra_instance.cpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_intra_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Intrawave, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..66874c5696 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/merged_groups/device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd3x3, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt index f36d55d367..47fc2655bb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -3,6 +3,7 @@ set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) add_instance_library(device_grouped_conv3d_fwd_scale_instance ${GROUPED_CONV3D_FWD_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp new file mode 100644 index 0000000000..5377cc56bd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + Scale, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/utility/host_tensor.cpp b/library/src/utility/host_tensor.cpp index 02bd562e43..cc394f2535 100644 --- a/library/src/utility/host_tensor.cpp +++ b/library/src/utility/host_tensor.cpp @@ -5,18 +5,6 @@ #include "ck/library/utility/host_tensor.hpp" -void HostTensorDescriptor::CalculateStrides() -{ - mStrides.clear(); - mStrides.resize(mLens.size(), 0); - if(mStrides.empty()) - return; - - mStrides.back() = 1; - std::partial_sum( - mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); -} - std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); } std::size_t HostTensorDescriptor::GetElementSize() const @@ -57,3 +45,14 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) return os; } + +std::ostream& operator<<(std::ostream& os, HostTensorDescriptor::ChosenLayout tag) +{ + switch(tag) + { + case HostTensorDescriptor::ChosenLayout::Original: os << "Original"; break; + case HostTensorDescriptor::ChosenLayout::RowMajor: os << "RowMajor"; break; + case HostTensorDescriptor::ChosenLayout::ColumnMajor: os << "ColumnMajor"; break; + } + return os; +} diff --git a/profiler/include/profiler/profile_avg_pool2d_bwd_impl.hpp b/profiler/include/profiler/profile_avg_pool2d_bwd_impl.hpp index caf24f016a..7cf0fed74f 100644 --- a/profiler/include/profiler/profile_avg_pool2d_bwd_impl.hpp +++ b/profiler/include/profiler/profile_avg_pool2d_bwd_impl.hpp @@ -82,7 +82,9 @@ bool profile_avg_pool2d_bwd_impl(int do_verification, [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { using namespace ck::literals; - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, + {C_ * H * W, 1_uz, W * C_, C_}, + ck::tensor_layout::convolution::NCHW{}); }; Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo)); diff --git a/profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp b/profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp index e7e8f7213f..fba8f6f67f 100644 --- a/profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp +++ b/profiler/include/profiler/profile_avg_pool3d_bwd_impl.hpp @@ -93,7 +93,8 @@ bool profile_avg_pool3d_bwd_impl(int do_verification, using namespace ck::literals; return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, + ck::tensor_layout::convolution::NDHWC{}); }; Tensor dout_n_c_do_ho_wo(f_host_tensor_descriptor(N, C, Do, Ho, Wo)); diff --git a/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp b/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp index 22dab31100..4b0b8e5bcb 100644 --- a/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_add_relu_gemm_add_impl.hpp @@ -116,11 +116,13 @@ bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification, if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp b/profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp index a91191b33d..060fbd70e5 100644 --- a/profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_b_scale_impl.hpp @@ -66,11 +66,13 @@ bool profile_batched_gemm_b_scale_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp index be69b67b5c..2f6a50cbd4 100644 --- a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp @@ -20,6 +20,10 @@ #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + namespace ck { namespace profiler { @@ -107,12 +111,12 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, const int BatchCount = G0 * G1; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Row{}); + Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides, Row{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Col{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp index 8089f9efc7..a8571d0779 100644 --- a/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_gemm_impl.hpp @@ -110,11 +110,13 @@ bool profile_batched_gemm_gemm_impl(bool do_verification, if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/profiler/include/profiler/profile_batched_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_impl.hpp index 92e06e4a70..79ca7029c6 100644 --- a/profiler/include/profiler/profile_batched_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_impl.hpp @@ -61,11 +61,13 @@ bool profile_batched_gemm_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp index 901fa338d4..cb91d8090d 100644 --- a/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_reduce_impl.hpp @@ -83,11 +83,13 @@ bool profile_batched_gemm_reduce_impl(int do_verification, if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {row * stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {col * stride, 1_uz, stride}, layout); } }; diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp index 700ada73a1..03fa1b1371 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp @@ -118,11 +118,13 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp index e3c462e21c..2945a4a66d 100644 --- a/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp +++ b/profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp @@ -20,6 +20,9 @@ #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + namespace ck { namespace profiler { @@ -101,11 +104,11 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification, const int BatchCount = G0 * G1; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Row{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Bypass{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/profiler/include/profiler/profile_contraction_impl.hpp b/profiler/include/profiler/profile_contraction_impl.hpp index 604032a01d..616e824ce1 100644 --- a/profiler/include/profiler/profile_contraction_impl.hpp +++ b/profiler/include/profiler/profile_contraction_impl.hpp @@ -60,19 +60,29 @@ int profile_contraction_impl(ck::index_t do_verification, auto f_host_tensor_descriptor = [](const std::vector& dims01, const std::vector& dims23, - const std::vector& strides) { + const std::vector& strides, + auto layout) { std::vector dims_szt(dims01.begin(), dims01.end()); dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end()); - std::vector strides_szt(strides.begin(), strides.end()); - return HostTensorDescriptor(dims_szt, strides); + // For ColumnMajor with more than 2 dimensions, the strides are custom-defined, so skip + // verification. + if constexpr(ck::is_same_v) + { + if(strides.size() > 2) + { + return HostTensorDescriptor( + dims_szt, strides, ck::tensor_layout::BypassLayoutVerification{}); + } + } + return HostTensorDescriptor(dims_szt, strides, layout); }; - Tensor a_m_k(f_host_tensor_descriptor(M, K, StridesA)); - Tensor b_n_k(f_host_tensor_descriptor(N, K, StridesB)); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE)); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StridesD)); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StridesA, ALayout{})); + Tensor b_n_k(f_host_tensor_descriptor(N, K, StridesB, BLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE, CDELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE, CDELayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StridesD, CDELayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_n_k: " << b_n_k.mDesc << std::endl; @@ -160,7 +170,7 @@ int profile_contraction_impl(ck::index_t do_verification, auto ref_op = ReferenceGemmInstance{}; auto ref_invoker = ref_op.MakeInvoker(); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE, CDELayout{})); auto ref_argument = ref_op.MakeArgument(a_m_k, b_n_k, c_m_n_host_result, a_element_op, b_element_op); diff --git a/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp b/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp index 14182bb7b0..aafb7b260d 100644 --- a/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp +++ b/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp @@ -100,12 +100,12 @@ static auto create_gemm_desc(const ck::index_t G, const ck::index_t NDoHoWo, con if constexpr(std::is_same_v || std::is_same_v || std::is_same_v) { - return HostTensorDescriptor({G, NDoHoWo, CZYX}); + return HostTensorDescriptor({G, NDoHoWo, CZYX}, InputLayout{}); } else if constexpr(std::is_same_v || std::is_same_v || std::is_same_v) { - return HostTensorDescriptor({G, NDoHoWo, CZYX}, {CZYX, CZYX * G, 1}); + return HostTensorDescriptor({G, NDoHoWo, CZYX}, {CZYX, CZYX * G, 1}, InputLayout{}); } else { diff --git a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp index d68a1065ab..f17516a47d 100644 --- a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp @@ -19,7 +19,6 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -75,10 +74,6 @@ bool profile_gemm_ab_scale_impl(int do_verification, ? ((K + ScaleBlockK - 1) / ScaleBlockK) : ((N + ScaleBlockN - 1) / ScaleBlockN); - ck::utils::validate_gemm_stride(M, K, StrideA, "StrideA"); - ck::utils::validate_gemm_stride(K, N, StrideB, "StrideB"); - ck::utils::validate_gemm_stride(M, N, StrideE, "StrideE"); - Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor a1_m_k(f_host_tensor_descriptor((M + ScaleBlockM - 1) / ScaleBlockM, (K + ScaleBlockK - 1) / ScaleBlockK, diff --git a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp index 46591a3525..a8daf4e787 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp @@ -136,19 +136,27 @@ bool profile_gemm_add_relu_add_layernorm_impl(int do_verification, return HostTensorDescriptor({len}, {stride}); }; - auto f_host_tensor_descriptor2d = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor2d = [](std::size_t row, + std::size_t col, + int& stride, + auto layout) { + using namespace ck::literals; - if constexpr(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp index 5d79a98c11..e7f4338ef0 100644 --- a/profiler/include/profiler/profile_gemm_add_relu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_add_relu_impl.hpp @@ -43,19 +43,24 @@ bool profile_gemm_add_relu_impl(int do_verification, int StrideD0, int StrideE) { - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, int& stride, auto layout) { + using namespace ck::literals; - if(is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp index 405a2359c2..b265101f3f 100644 --- a/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_bias_add_reduce_impl.hpp @@ -15,7 +15,6 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -86,17 +85,14 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 33a889afe7..0921b48842 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -20,7 +20,6 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/validation_common.hpp" namespace ck { namespace profiler { @@ -86,29 +85,30 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, { bool pass = true; - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, int& stride, auto layout) { + using namespace ck::literals; - if(is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; ck::index_t Scale_Stride_AM = ((M + ScaleBlockM - 1) / ScaleBlockM); ck::index_t Scale_Stride_BN = ck::is_same_v ? ((K + ScaleBlockK - 1) / ScaleBlockK) : ((N + ScaleBlockN - 1) / ScaleBlockN); - ck::utils::validate_gemm_stride(M, K, StrideA, "StrideA"); - ck::utils::validate_gemm_stride(K, N, StrideB, "StrideB"); - ck::utils::validate_gemm_stride(M, N, StrideE, "StrideE"); - Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor a1_m_k(f_host_tensor_descriptor((M + ScaleBlockM - 1) / ScaleBlockM, (K + ScaleBlockK - 1) / ScaleBlockK, diff --git a/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp b/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp index 3893f8cdc7..0fe8abe242 100644 --- a/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp +++ b/profiler/include/profiler/profile_gemm_fastgelu_impl.hpp @@ -40,19 +40,24 @@ bool profile_gemm_fastgelu_impl(int do_verification, int StrideB, int StrideE) { - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, int& stride, auto layout) { + using namespace ck::literals; - if(is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index fdcb3ad128..93eac048cd 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -24,7 +24,6 @@ #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/utility/fill.hpp" -#include "ck/library/utility/validation_common.hpp" namespace ck { namespace profiler { @@ -57,17 +56,14 @@ int profile_gemm_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp new file mode 100644 index 0000000000..a3c5c6a3ac --- /dev/null +++ b/profiler/include/profiler/profile_gemm_multi_abd_impl.hpp @@ -0,0 +1,424 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_multi_abd.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +// this function is also defined in CK but because of the way we use it in +// profile_gemm_multi_impl, it requires the arguments to not be const +template +auto concat_tuple_of_refs(ck::Tuple& tx, ck::Tuple& ty) +{ + return ck::unpack2( + [&](auto&&... zs) { return ck::Tuple{ck::forward(zs)...}; }, + tx, + ty); +} + +template +bool profile_gemm_multi_abd_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + static constexpr index_t NumATensor = AsDataType::Size(); + auto as_m_k = generate_tuple( + [&](auto i) { + using ADataType = remove_cvref_t>; + using ALayout = remove_cvref_t>; + + return Tensor(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + }, + Number{}); + + static constexpr index_t NumBTensor = BsDataType::Size(); + auto bs_k_n = generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + return Tensor(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + }, + Number{}); + + static constexpr index_t NumDTensor = DsDataType::Size(); + auto ds_m_n = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + using DLayout = remove_cvref_t>; + + return Tensor(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + }, + Number{}); + + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + static_for<0, NumATensor, 1>{}( + [&](auto i) { std::cout << "a" << i.value << "_m_k: " << as_m_k(i).mDesc << std::endl; }); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { std::cout << "b" << i.value << "_k_n: " << bs_k_n(i).mDesc << std::endl; }); + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "d" << i.value << "_m_n: " << ds_m_n(i).mDesc << std::endl; }); + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + + as_m_k(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + + bs_k_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + ds_m_n(i).GenerateTensorValue(GeneratorTensor_2{-5, 5}); + }); + + break; + default: + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + + as_m_k(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + + bs_k_n(i).GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + ds_m_n(i).GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + }); + } + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleABD; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + Tensor c_m_n({M, N}); + + using AComputeType = + typename std::conditional<(NumATensor > 1), + EDataType, + remove_cvref_t>>::type; + + auto get_a_matrix = [&]() -> auto { + // in case of pass through we avoid allocating a new + // tensor and copying values + if constexpr(is_same_v) + { + return as_m_k(Number<0>{}); + } + else + { + Tensor a_m_k({M, K}); + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + // result + auto data_refs1 = ck::tie(a_m_k(m, k)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return as_m_k(Number{})(m, k); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(a_element_op, data_refs); + } + } + return a_m_k; + } + }; + + using BComputeType = + typename std::conditional<(NumBTensor > 1), + EDataType, + remove_cvref_t>>::type; + + auto get_b_matrix = [&]() -> auto { + // in case of pass through we avoid allocating a new + // tensor and copying values + if constexpr(is_same_v) + { + return bs_k_n(Number<0>{}); + } + else + { + Tensor b_k_n({K, N}); + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N; ++n) + { + // result + auto data_refs1 = ck::tie(b_k_n(k, n)); + // inputs + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return bs_k_n(Number{})(k, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(b_element_op, data_refs); + } + } + return b_k_n; + } + }; + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + get_a_matrix(), get_b_matrix(), c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + // compulsory + auto data_refs1 = ck::tie(e_m_n_host_result(m, n), c_m_n(m, n)); + // optional (if multiple Ds) + auto data_refs2 = + generate_tie([&](auto i) -> auto& { return ds_m_n(Number{})(m, n); }, + Number{}); + auto data_refs = concat_tuple_of_refs(data_refs1, data_refs2); + unpack(cde_element_op, data_refs); + } + } + } + + std::array as_device_buf; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + as_device_buf[i] = new DeviceMem(sizeof(ADataType) * as_m_k(i).mDesc.GetElementSpaceSize()); + }); + + std::array bs_device_buf; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + bs_device_buf[i] = new DeviceMem(sizeof(BDataType) * bs_k_n(i).mDesc.GetElementSpaceSize()); + }); + + std::array ds_device_buf; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + ds_device_buf[i] = new DeviceMem(sizeof(DDataType) * ds_m_n(i).mDesc.GetElementSpaceSize()); + }); + + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + static_for<0, NumATensor, 1>{}( + [&](auto i) { as_device_buf[i]->ToDevice(as_m_k(i).mData.data()); }); + + static_for<0, NumBTensor, 1>{}( + [&](auto i) { bs_device_buf[i]->ToDevice(bs_k_n(i).mData.data()); }); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_device_buf[i]->ToDevice(ds_m_n(i).mData.data()); }); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + std::array as_pointer; + std::array as_stride; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_pointer[i] = as_device_buf[i]->GetDeviceBuffer(); + as_stride[i] = StrideA; + }); + + std::array bs_pointer; + std::array bs_stride; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_pointer[i] = bs_device_buf[i]->GetDeviceBuffer(); + bs_stride[i] = StrideB; + }); + std::array ds_pointer; + std::array ds_stride; + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_pointer[i] = ds_device_buf[i]->GetDeviceBuffer(); + ds_stride[i] = StrideD; + }); + + auto argument_ptr = op_ptr->MakeArgumentPointer(as_pointer, + bs_pointer, + ds_pointer, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + as_stride, + bs_stride, + ds_stride, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t sizeADataType = 0; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + sizeADataType = std::max(sizeADataType, sizeof(ADataType)); + }); + std::size_t sizeBDataType = 0; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + sizeBDataType = std::max(sizeBDataType, sizeof(BDataType)); + }); + + std::size_t num_btype = + sizeADataType * M * K + sizeBDataType * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + static_for<0, NumATensor, 1>{}([&](auto i) { delete as_device_buf[i]; }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { delete bs_device_buf[i]; }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { delete ds_device_buf[i]; }); + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_multiply_add_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_add_impl.hpp index f9a5a995fe..2711d595d6 100644 --- a/profiler/include/profiler/profile_gemm_multiply_add_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multiply_add_impl.hpp @@ -46,20 +46,25 @@ bool profile_gemm_multiply_add_impl(int do_verification, int StrideD1, int StrideE) { - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - if(is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, int& stride, auto layout) { + using namespace ck::literals; + if(is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); @@ -117,6 +122,11 @@ bool profile_gemm_multiply_add_impl(int do_verification, const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); + if(op_ptrs.size() == 0) + { + std::cout << "No device operation instances found." << std::endl; + return false; + } std::cout << "found " << op_ptrs.size() << " instances" << std::endl; // run reference diff --git a/profiler/include/profiler/profile_gemm_quantization_impl.hpp b/profiler/include/profiler/profile_gemm_quantization_impl.hpp index a115a41a34..02f374164e 100644 --- a/profiler/include/profiler/profile_gemm_quantization_impl.hpp +++ b/profiler/include/profiler/profile_gemm_quantization_impl.hpp @@ -47,11 +47,11 @@ bool profile_gemm_quantization_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; diff --git a/profiler/include/profiler/profile_gemm_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_reduce_impl.hpp index a74d2a01d9..470cc86d1b 100644 --- a/profiler/include/profiler/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_reduce_impl.hpp @@ -15,7 +15,6 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -81,17 +80,14 @@ bool profile_gemm_reduce_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index 0640e95aba..8032730199 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -19,7 +19,6 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -55,17 +54,14 @@ bool profile_gemm_splitk_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/profiler/include/profiler/profile_gemm_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_streamk_impl.hpp index d24ee1c7ea..f86e7ad447 100644 --- a/profiler/include/profiler/profile_gemm_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_streamk_impl.hpp @@ -19,7 +19,6 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -52,17 +51,14 @@ bool profile_gemm_streamk_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/profiler/include/profiler/profile_gemm_universal_batched_impl.hpp b/profiler/include/profiler/profile_gemm_universal_batched_impl.hpp index f4300af8d8..99e24cd205 100644 --- a/profiler/include/profiler/profile_gemm_universal_batched_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_batched_impl.hpp @@ -65,11 +65,13 @@ bool profile_gemm_universal_batched_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index feb75c9660..bb73c4e3da 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -19,7 +19,6 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -56,17 +55,14 @@ bool profile_gemm_universal_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp index 271bc6ef59..e537cf2770 100644 --- a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp @@ -19,7 +19,6 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -84,17 +83,14 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp b/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp index 32d2b38def..554956ee88 100644 --- a/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_reduce_impl.hpp @@ -20,7 +20,6 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" namespace ck { @@ -58,17 +57,14 @@ bool profile_gemm_universal_reduce_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp index 5c859b830d..035a1b77df 100644 --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -21,7 +21,6 @@ #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" -#include "ck/library/utility/validation_common.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp" @@ -60,17 +59,14 @@ bool profile_gemm_universal_streamk_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp index cd6c141219..91ac2a0ab6 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_bnorm_clamp_impl.hpp @@ -32,6 +32,7 @@ using OutElementOp = ck::tensor_operation::element_wise::BiasNormalizeInInferCla using Clamp = ck::tensor_operation::element_wise::Clamp; using Add = ck::tensor_operation::element_wise::Add; +using BaseConv = ck::tensor_layout::convolution::BaseConvolutionLayout; // NOTE: Usage of NHWGK layout for GK bias is a workaround. This test is to // just keep such implementation valid. // TODO: Add possiblity to pass GK layout and GK lengths for bias and reuse @@ -42,15 +43,15 @@ auto get_elementwise_desc(ck::index_t G, ck::index_t K) { if constexpr(NDimSpatial == 1) { - return HostTensorDescriptor({G, 1, K, 1}, {K, 0, 1, 0}); + return HostTensorDescriptor({G, 1, K, 1}, {K, 0, 1, 0}, BaseConv{}); } else if constexpr(NDimSpatial == 2) { - return HostTensorDescriptor({G, 1, K, 1, 1}, {K, 0, 1, 0, 0}); + return HostTensorDescriptor({G, 1, K, 1, 1}, {K, 0, 1, 0, 0}, BaseConv{}); } else { - return HostTensorDescriptor({G, 1, K, 1, 1, 1}, {K, 0, 1, 0, 0, 0}); + return HostTensorDescriptor({G, 1, K, 1, 1, 1}, {K, 0, 1, 0, 0, 0}, BaseConv{}); } } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index d0e1cf2611..188d7aa0b0 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -25,6 +25,8 @@ namespace ck { namespace profiler { +using BaseConv = ck::tensor_layout::convolution::BaseConvolutionLayout; + // NOTE: Usage of NHWGK layout for GK bias is a workaround. This test is to // just keep such implementation valid. // TODO: Add possiblity to pass GK layout and GK lengths for bias and reuse @@ -35,15 +37,15 @@ auto get_bias_desc(ck::index_t G, ck::index_t K) { if constexpr(NDimSpatial == 1) { - return HostTensorDescriptor({G, 1, K, 1}, {K, 0, 1, 0}); + return HostTensorDescriptor({G, 1, K, 1}, {K, 0, 1, 0}, BaseConv{}); } else if constexpr(NDimSpatial == 2) { - return HostTensorDescriptor({G, 1, K, 1, 1}, {K, 0, 1, 0, 0}); + return HostTensorDescriptor({G, 1, K, 1, 1}, {K, 0, 1, 0, 0}, BaseConv{}); } else { - return HostTensorDescriptor({G, 1, K, 1, 1, 1}, {K, 0, 1, 0, 0, 0}); + return HostTensorDescriptor({G, 1, K, 1, 1, 1}, {K, 0, 1, 0, 0, 0}, BaseConv{}); } } diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index fc2ba5a650..eef5e02911 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -57,11 +57,11 @@ bool profile_grouped_gemm_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; diff --git a/profiler/include/profiler/profile_max_pool2d_bwd_impl.hpp b/profiler/include/profiler/profile_max_pool2d_bwd_impl.hpp index 7a712f21f2..6e3de3a26a 100644 --- a/profiler/include/profiler/profile_max_pool2d_bwd_impl.hpp +++ b/profiler/include/profiler/profile_max_pool2d_bwd_impl.hpp @@ -82,7 +82,9 @@ bool profile_max_pool2d_bwd_impl(int do_verification, [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { using namespace ck::literals; - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, + {C_ * H * W, 1_uz, W * C_, C_}, + ck::tensor_layout::convolution::NCHW{}); }; Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi)); diff --git a/profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp b/profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp index 15fb4e9034..407337f827 100644 --- a/profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp +++ b/profiler/include/profiler/profile_max_pool3d_bwd_impl.hpp @@ -84,7 +84,8 @@ bool profile_max_pool3d_bwd_impl(int do_verification, using namespace ck::literals; return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, + ck::tensor_layout::convolution::NDHWC{}); }; Tensor in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi)); diff --git a/profiler/include/profiler/profile_permute_scale_impl.hpp b/profiler/include/profiler/profile_permute_scale_impl.hpp index 186a24501e..9ccbd67783 100644 --- a/profiler/include/profiler/profile_permute_scale_impl.hpp +++ b/profiler/include/profiler/profile_permute_scale_impl.hpp @@ -40,10 +40,13 @@ bool profile_permute_scale_impl(int do_verification, using ElementOp = ck::tensor_operation::element_wise::Scale; float scale = 2.f; - std::array, 1> as = {Tensor(lengths_vector, input_strides_vector)}; - Tensor& a = as[0]; - Tensor b(lengths_vector, output_strides_vector); - Tensor host_b(lengths_vector, output_strides_vector); + using ALayout = ck::tensor_layout::BypassLayoutVerification; + using BLayout = ck::tensor_layout::BypassLayoutVerification; + std::array, 1> as = { + Tensor(lengths_vector, input_strides_vector, ALayout{})}; + Tensor& a = as[0]; + Tensor b(lengths_vector, output_strides_vector, BLayout{}); + Tensor host_b(lengths_vector, output_strides_vector, BLayout{}); std::cout << "A: " << a.mDesc << std::endl; std::cout << "B: " << b.mDesc << std::endl; diff --git a/profiler/include/profiler/profile_pool2d_fwd_impl.hpp b/profiler/include/profiler/profile_pool2d_fwd_impl.hpp index 23226a4881..88162b9417 100644 --- a/profiler/include/profiler/profile_pool2d_fwd_impl.hpp +++ b/profiler/include/profiler/profile_pool2d_fwd_impl.hpp @@ -74,7 +74,9 @@ bool profile_pool2d_fwd_impl(int do_verification, [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { using namespace ck::literals; - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, + {C_ * H * W, 1_uz, W * C_, C_}, + ck::tensor_layout::convolution::NCHW{}); }; Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi)); diff --git a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp index cbdacad53b..412946d558 100644 --- a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp +++ b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp @@ -91,7 +91,8 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& using namespace ck::literals; return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, + ck::tensor_layout::convolution::NDHWC{}); }; Tensor in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi)); diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 31f684fe75..c31ede2c73 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -96,6 +96,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) + list(APPEND PROFILER_OPS profile_gemm_multi_abd.cpp) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp) @@ -234,6 +235,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1 list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_gemm_add_relu_instance) + list(APPEND DEVICE_INSTANCES device_gemm_multi_abd_instance) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND DEVICE_INSTANCES device_gemm_add_multiply_instance) list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance) diff --git a/profiler/src/profile_gemm_multi_abd.cpp b/profiler/src/profile_gemm_multi_abd.cpp new file mode 100644 index 0000000000..157bcbc977 --- /dev/null +++ b/profiler/src/profile_gemm_multi_abd.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_multi_abd_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + BF16_I8_BF16_BF16, // 0 +}; + +enum struct GemmElementOp +{ + PASS_THROUGH, // 0 + MULTIPLY, // 1 + ADD, // 2 + FASTGELU, // 3 + ADD_FASTGELU, // 4 + MULTIPLY_ADD, // 5 + MULTIPLY_FASTGELU, // 6 + MULTIPLY_ADD_FASTGELU, // 7 +}; + +#define OP_NAME "gemm_multi_abd" +#define OP_DESC "GEMM_Multiple_ABD" + +int profile_gemm_multi_abd(int argc, char* argv[]) +{ + if(argc != 18) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: bf16@int8/bf16->bf16;)\n"); + printf("arg3: matrix layout (0: E[m, n] = A[m, k] * B[k, n];\n"); + printf(" 1: E[m, n] = A[m, k] * B[n, k];\n"); + printf(" 2: E[m, n] = A[k, m] * B[k, n];\n"); + printf(" 3: E[m, n] = A[k, m] * B[n, k])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8: number of As (1)\n"); + printf("arg9: number of Bs (1/2)\n"); + printf("arg10: number of Ds (0/1/2)\n"); + printf("arg11 to 17: M, N, K, StrideA, StrideB, StrideE, StrideD\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int num_as = std::stoi(argv[8]); + const int num_bs = std::stoi(argv[9]); + const int num_ds = std::stoi(argv[10]); + + const int M = std::stoi(argv[11]); + const int N = std::stoi(argv[12]); + const int K = std::stoi(argv[13]); + + const int StrideA = std::stoi(argv[14]); + const int StrideB = std::stoi(argv[15]); + const int StrideE = std::stoi(argv[16]); + const int StrideD = std::stoi(argv[17]); + + using F32 = float; + using BF16 = ck::bhalf_t; + using I8 = int8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using Multiply = ck::tensor_operation::element_wise::Multiply; + using FastGelu = ck::tensor_operation::element_wise::FastGelu; + using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; + + auto profile = [&](auto b_layout, auto b_element_op, auto cde_element_op, auto num_d_tensor) { + using ADataType = BF16; + using B0DataType = I8; + using B1DataType = BF16; + using DDataType = BF16; + using EDataType = BF16; + + using ALayout = Row; + using BLayout = decltype(b_layout); + using DLayout = Row; + using ELayout = Row; + + using AElementOp = PassThrough; + using BElementOp = decltype(b_element_op); + using CDEElementOp = decltype(cde_element_op); + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + constexpr auto NumberDTensor = decltype(num_d_tensor){}; + + // Only num_d_tensor == 0 and 1 are supported + using DsDataType = typename std:: + conditional<(NumberDTensor == 0), ck::Tuple<>, ck::Tuple>::type; + using DsLayout = + typename std::conditional<(NumberDTensor == 0), ck::Tuple<>, ck::Tuple>::type; + + bool pass = ck::profiler::profile_gemm_multi_abd_impl, + ck::Tuple, + F32, + DsDataType, + EDataType, + ck::Tuple, + ck::Tuple, + DsLayout, + ELayout, + AElementOp, + BElementOp, + CDEElementOp>( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD < 0) ? DefaultStrideD : StrideD, + (StrideE < 0) ? DefaultStrideE : StrideE); + + return pass ? 0 : 1; + }; + + // num_as == 1 is only supported + if(data_type != GemmDataType::BF16_I8_BF16_BF16 || num_as != 1) + { + std::cout << "The provided input parameters are not supported" << std::endl; + return 1; + } + + // Supported configurations + if(layout == GemmMatrixLayout::MK_KN_MN && num_bs == 2 && num_ds == 1) + { + return profile(Row{}, Multiply{}, AddFastGelu{}, ck::Number<1>{}); + } + else if(layout == GemmMatrixLayout::MK_KN_MN && num_bs == 2 && num_ds == 0) + { + return profile(Row{}, Multiply{}, FastGelu{}, ck::Number<0>{}); + } + else if(layout == GemmMatrixLayout::MK_NK_MN && num_bs == 2 && num_ds == 1) + { + return profile(Col{}, Multiply{}, AddFastGelu{}, ck::Number<1>{}); + } + else if(layout == GemmMatrixLayout::MK_NK_MN && num_bs == 2 && num_ds == 0) + { + return profile(Col{}, Multiply{}, FastGelu{}, ck::Number<0>{}); + } + + std::cout << "The provided input parameters are not supported" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_multi_abd); diff --git a/profiler/src/profile_gemm_multiply_add.cpp b/profiler/src/profile_gemm_multiply_add.cpp index 98973b2f01..88d3b5256a 100644 --- a/profiler/src/profile_gemm_multiply_add.cpp +++ b/profiler/src/profile_gemm_multiply_add.cpp @@ -92,12 +92,6 @@ int profile_gemm_multiply_add(int argc, char* argv[]) using D1Layout = decltype(d1_layout); using ELayout = decltype(e_layout); - const int DefaultStrideA = ck::is_same_v ? K : M; - const int DefaultStrideB = ck::is_same_v ? N : K; - const int DefaultStrideD0 = ck::is_same_v ? N : M; - const int DefaultStrideD1 = ck::is_same_v ? N : M; - const int DefaultStrideE = ck::is_same_v ? N : M; - bool pass = ck::profiler::profile_gemm_multiply_add_impl( - do_verification, - init_method, - do_log, - time_kernel, - M, - N, - K, - (StrideA < 0) ? DefaultStrideA : StrideA, - (StrideB < 0) ? DefaultStrideB : StrideB, - (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, - (StrideD1 < 0) ? DefaultStrideD1 : StrideD1, - (StrideE < 0) ? DefaultStrideE : StrideE); + ELayout>(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + StrideA, + StrideB, + StrideD0, + StrideD1, + StrideE); return pass ? 0 : 1; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index cedac568db..df3a03cca8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -243,6 +243,7 @@ add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) add_subdirectory(gemm_add) add_subdirectory(gemm_layernorm) +add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) add_subdirectory(gemm_b_scale) diff --git a/test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp b/test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp index 6c04086e0e..eba461a420 100644 --- a/test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp +++ b/test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp @@ -56,7 +56,21 @@ class TestBatchedGemmMultiD : public ::testing::Test PassThrough, PassThrough, PassThrough>>( - true, 1, false, 1, M, N, K, K, N, N, M * K, K * N, M * N, BatchCount); + true, // do_verification + 1, // init_method + false, // do_log + 1, // time_kernel, + M, + N, + K, + std::is_same_v ? K : M, // strideA + std::is_same_v ? N : K, // strideB + std::is_same_v ? N : M, // strideC + // BatchStrideA BatchStrideB, BatchStrideC + M * K, + K * N, + M * N, + BatchCount); EXPECT_TRUE(pass); } }; diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index b08f0d8316..b92888b1f1 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(gemm_weight_preshuffle) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm_preshuffle) +add_subdirectory(grouped_gemm_multi_d) add_subdirectory(gemm_multi_d) add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_streamk) diff --git a/test/ck_tile/fmha/test_fmha_bwd.inc b/test/ck_tile/fmha/test_fmha_bwd.inc index 1ad321ec99..704b5c7bf7 100644 --- a/test/ck_tile/fmha/test_fmha_bwd.inc +++ b/test/ck_tile/fmha/test_fmha_bwd.inc @@ -111,6 +111,9 @@ class HDimPadding INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd, HDimPadding, Combine(Values(std::tuple{24, 48}, + std::tuple{48, 48}, + std::tuple{72, 72}, + std::tuple{96, 96}, std::tuple{120, 160}, std::tuple{256, 108}, std::tuple{40, 64}), diff --git a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp index cd143e8e83..077e45a10d 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_bf16.cpp @@ -16,6 +16,6 @@ const auto HDimValues = const auto ModeValues = Values(mode_enum::batch, mode_enum::group); -constexpr std::string init_method = "uf"; +constexpr auto init_method = "uf"; #include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp index 4bb1e04ad0..86621b0494 100644 --- a/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd_fp16.cpp @@ -16,6 +16,6 @@ const auto HDimValues = const auto ModeValues = Values(mode_enum::batch, mode_enum::group); -constexpr std::string init_method = "uf"; +constexpr auto init_method = "uf"; #include "test_fmha_bwd.inc" diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index 08abd3358d..9497122594 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -98,7 +98,10 @@ TEST_P(AllLong, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {seqlen_kpad}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim perm, // i_perm perm, // o_perm @@ -121,6 +124,141 @@ TEST_P(AllLong, Test) CHECK_RESULT(result); } +// --------------------------------------------------------------- +// Negative tests: padding not supported with appendkv/splitkv/pagedkv +// --------------------------------------------------------------- + +#if CK_TILE_FMHA_FWD_APPENDKV_API +TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail) +{ + // batch mode effective lengths simulate padding + auto result = fmha_fwd_run( + mode_enum::batch, + 2, // batch + 4, // nhead + -1, // nhead_k + {128}, // seqlen_qs + {128}, // seqlen_ks + 64, // hdim_q + 64, // hdim_v + 32, // seqlen_knew -> triggers appendkv + {}, // seqlen_qpads + {}, // seqlen_kpads + {100, 120}, // q_eff_lens_per_batch + {90, 110}, // kv_eff_lens_per_batch + 0, // rotary_dim + true, // i_perm + true, // o_perm + 0, // scale_s + 0, // logits_soft_cap + def_is_v_rowmajor, + def_lse, + 0, // page_block_size + false, // use_cache_batch_idx + "n", // bias + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + "0", // mask + squant, + true, // is_rotary_interleaved + 1, // num_splits + init_method, + static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), + 0, + stream_config); + ASSERT_EQ(result, fwd_result::invalid_args); +} +#endif + +#if CK_TILE_FMHA_FWD_SPLITKV_API +TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail) +{ + // group mode physical padding + auto result = fmha_fwd_run( + mode_enum::group, + 2, // batch + 4, // nhead + -1, // nhead_k + {96, 120}, // seqlen_qs logical + {96, 120}, // seqlen_ks logical + 64, // hdim_q + 64, // hdim_v + 0, // seqlen_knew + {128, 128}, // seqlen_qpads + {128, 128}, // seqlen_kpads + {}, // q_eff + {}, // kv_eff + 0, // rotary_dim + true, // i_perm + true, // o_perm + 0, // scale_s + 0, // logits_soft_cap + def_is_v_rowmajor, + def_lse, + 0, // page_block_size + false, // use_cache_batch_idx + "n", // bias + 0.0f, + 0, + 0, + false, + "0", + squant, + true, + 2, // num_splits (>1 triggers splitkv) + init_method, + static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), + 0, + stream_config); + ASSERT_EQ(result, fwd_result::invalid_args); +} +#endif + +#if CK_TILE_FMHA_FWD_PAGEDKV_API +TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail) +{ + auto result = fmha_fwd_run( + mode_enum::group, + 2, + 4, + -1, + {80, 100}, + {80, 100}, + 64, + 64, + 0, // seqlen_knew + {96, 128}, // seqlen_qpads + {96, 128}, // seqlen_kpads + {}, + {}, + 0, + true, + true, + 0, + 0, + def_is_v_rowmajor, + def_lse, + 128, // page_block_size triggers pagedkv + false, + "n", + 0.0f, + 0, + 0, + false, + "0", + squant, + true, + 1, + init_method, + static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), + 0, + stream_config); + ASSERT_EQ(result, fwd_result::invalid_args); +} +#endif + class HDimPadding : public TestWithParam, bool, @@ -160,7 +298,10 @@ TEST_P(HDimPadding, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {seqlen_kpad}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim perm, // i_perm perm, // o_perm @@ -217,7 +358,10 @@ TEST_P(ElementwiseBias, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -273,7 +417,10 @@ TEST_P(Alibi, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim true, // i_perm true, // o_perm @@ -331,7 +478,10 @@ TEST_P(Dropout, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim false, // i_perm false, // o_perm @@ -391,7 +541,10 @@ TEST_P(PagedKV, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -457,7 +610,10 @@ TEST_P(SplitKV, Test) hdim_q, hdim_v, 0, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm false, // o_perm @@ -529,7 +685,10 @@ TEST_P(AppendKV, Test) hdim_q, hdim_v, seqlen_knew, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch 0, // rotary_dim i_perm, // i_perm true, // o_perm @@ -599,7 +758,10 @@ TEST_P(AppendKVRoPE, Test) hdim_q, hdim_v, seqlen_knew, // seqlen_knew + {-1}, // seqlen_qpads {-1}, // seqlen_kpads + {}, // q_eff_lens_per_batch + {}, // kv_eff_lens_per_batch rotary_dim, // rotary_dim i_perm, // i_perm true, // o_perm @@ -623,3 +785,294 @@ TEST_P(AppendKVRoPE, Test) } #endif // CK_TILE_FMHA_FWD_APPENDKV_API + +// --------------------------------------------------------------- +// Parameterized padding tests (batch & group) using Combine+Values +// --------------------------------------------------------------- + +using PaddingParam = std::tuple, // seqlen_qs (logical) + std::vector, // seqlen_ks (logical) + std::vector, // seqlen_qpads (physical padded lengths) + std::vector, // seqlen_kpads (physical padded lengths) + std::vector, // q_eff_lens + std::vector, // kv_eff_lens + bool, // i_perm + bool, // o_perm + std::string>; // mask_str + +// Ensure headers for containers / algorithms used in padding param builder. +#include +#include +#include +#include + +class PaddingCases : public TestWithParam +{ +}; + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(PaddingCases); + +// Build padding test params programmatically to enforce constraints +static std::vector BuildPaddingParams() +{ + std::vector params; + + // mask variants to cover + const std::vector mask_variants{"0", "t:50,64", "b:32,40"}; + const std::vector mask_variants_reduced{"0", "t:50,64"}; // used for trimmed sets + + // Representative ratio pairs (q_ratio, k_ratio) to avoid explosion + const std::vector> ratio_pairs_full{ + {1.0, 1.0}, // both full + {1.0, 0.5}, // q full, k half + {0.5, 1.0}, // q half, k full + }; + const std::vector> ratio_pairs_reduced{{1.0, 1.0}, {0.5, 1.0}}; + + // candidate physical seqlens for batch mode (single value) & for group mode (per batch) + const std::vector physical_lengths_full{64, 128, 256}; + const std::vector physical_lengths_reduced{64}; + + // batch sizes to sample + const std::vector batch_sizes{1, 4}; + // -------------------------------------------------------------------- + // Head configuration space (cover MHA, GQA, MQA) + // - Standard MHA: nhead_k == -1 (treated internally as nhead) + // - GQA: nhead_k > 0 and nhead % nhead_k == 0, nhead_k < nhead + // - MQA: nhead_k == 1 + // We choose (9, -1), (9, 3), (9, 1) so that divisibility holds. Full + // combinatorics only applied to the first (standard) configuration to + // avoid test explosion. + // -------------------------------------------------------------------- + struct HeadCfg + { + int nhead; + int nhead_k; // -1 for standard; else must divide nhead + bool full; // whether to use full coverage sets + }; + const std::vector head_cfgs = { + {9, -1, true}, // MHA full + {9, 3, false}, // GQA reduced (nhead/nhead_k=3) + {9, 1, false} // MQA reduced + }; + + // Helper to clamp and ensure >=1 + auto logical_len = [](int physical, double ratio) { + int v = static_cast(std::round(physical * ratio)); + v = std::max(1, std::min(v, physical)); + return v; + }; + // Iterate over head configurations + for(const auto& hc : head_cfgs) + { + const auto& ratio_pairs = hc.full ? ratio_pairs_full : ratio_pairs_reduced; + const auto& phys_lengths_batch = hc.full ? physical_lengths_full : physical_lengths_reduced; + const auto& phys_lengths_group_q = phys_lengths_batch; // reuse + const auto& phys_lengths_group_k = phys_lengths_batch; // reuse + const auto& masks = hc.full ? mask_variants : mask_variants_reduced; + + // ----------------- + // Batch mode params (effective lengths only) + // ----------------- + for(int b : batch_sizes) + { + for(int phys_qkv : phys_lengths_batch) + { + for(const auto& rkpair : ratio_pairs) + { + double rq = rkpair.first; + double rk = rkpair.second; + std::vector q_eff(b), kv_eff(b); + int log_q = logical_len(phys_qkv, rq); + int log_k = logical_len(phys_qkv, rk); + for(int i = 0; i < b; ++i) + { + q_eff[i] = log_q; + kv_eff[i] = log_k; + } + for(const auto& mask : masks) + { + params.emplace_back(PaddingParam{mode_enum::batch, + b, + hc.nhead, + hc.nhead_k, + {phys_qkv}, // seqlen_qs + {phys_qkv}, // seqlen_ks + {}, // seqlen_qpads + {}, // seqlen_kpads + q_eff, + kv_eff, + true, + true, + mask}); + } + } + // Single-token logical length case (both q & k = 1) + for(const auto& mask : masks) + { + std::vector q_eff(b, 1), kv_eff(b, 1); + params.emplace_back(PaddingParam{mode_enum::batch, + b, + hc.nhead, + hc.nhead_k, + {phys_qkv}, + {phys_qkv}, + {}, + {}, + q_eff, + kv_eff, + true, + true, + mask}); + } + } + } + + // ----------------- + // Group mode params (physical padding + logical variants) + // ----------------- + for(int b : batch_sizes) + { + for(int phys_q : phys_lengths_group_q) + { + for(int phys_k : phys_lengths_group_k) + { + for(const auto& rkpair : ratio_pairs) + { + double rq = rkpair.first; + double rk = rkpair.second; + std::vector seqlen_qs(b), seqlen_ks(b), seqlen_qpads(b), + seqlen_kpads(b); + for(int i = 0; i < b; ++i) + { + seqlen_qpads[i] = phys_q; + seqlen_kpads[i] = phys_k; + seqlen_qs[i] = logical_len(phys_q, rq); + seqlen_ks[i] = logical_len(phys_k, rk); + } + std::array, std::vector>, 3> pad_variants{ + std::pair{seqlen_qpads, seqlen_kpads}, // both + std::pair{seqlen_qpads, seqlen_ks}, // only q padding + std::pair{seqlen_qs, seqlen_kpads} // only kv padding + }; + for(const auto& mask : masks) + { + for(const auto& pv : pad_variants) + { + params.emplace_back(PaddingParam{mode_enum::group, + b, + hc.nhead, + hc.nhead_k, + seqlen_qs, + seqlen_ks, + pv.first, + pv.second, + {}, + {}, + true, + true, + mask}); + } + } + } + // Single-token logical length case + for(const auto& mask : masks) + { + std::vector seqlen_qs(b, 1), seqlen_ks(b, 1); + std::vector seqlen_qpads(b, phys_q), seqlen_kpads(b, phys_k); + // both padding variant only (others degenerate) + params.emplace_back(PaddingParam{mode_enum::group, + b, + hc.nhead, + hc.nhead_k, + seqlen_qs, + seqlen_ks, + seqlen_qpads, + seqlen_kpads, + {}, + {}, + true, + true, + mask}); + } + } + } + } + } + + return params; +} + +static const std::vector kPaddingParams = BuildPaddingParams(); + +INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd_Padding, PaddingCases, ValuesIn(kPaddingParams)); + +TEST_P(PaddingCases, Test) +{ + if constexpr(std::is_same_v) + { + GTEST_SKIP() << "Skip for fp8"; + } + + auto [mode, + batch, + nhead, + nhead_k, + seqlen_qs, + seqlen_ks, + seqlen_qpads, + seqlen_kpads, + q_eff_lens, + kv_eff_lens, + i_perm, + o_perm, + mask_str] = GetParam(); + + // For batch mode we wrap single logical lengths with adjust_seqlen. + std::vector adj_qs = + (mode == mode_enum::batch) ? std::vector{adjust_seqlen(seqlen_qs.at(0))} : seqlen_qs; + std::vector adj_ks = + (mode == mode_enum::batch) ? std::vector{adjust_seqlen(seqlen_ks.at(0))} : seqlen_ks; + + const int hdim_q = 64; + const int hdim_v = 64; + const int seqlen_knew = 0; + + auto result = fmha_fwd_run(mode, + batch, + nhead, + nhead_k, + adj_qs, + adj_ks, + hdim_q, + hdim_v, + seqlen_knew, // seqlen_knew + seqlen_qpads, // seqlen_qpads + seqlen_kpads, // seqlen_kpads + q_eff_lens, // q_eff_lens_per_batch + kv_eff_lens, // kv_eff_lens_per_batch + 0, // rotary_dim + i_perm, // i_perm + o_perm, // o_perm + 0, // scale_s + 0, // logits_soft_cap + def_is_v_rowmajor, + def_lse, // lse + 0, // page_block_size + false, // use_cache_batch_idx + "n", // bias_str + 0.0f, // p_drop + 0, // drop_seed + 0, // drop_offset + false, // drop_prefs + mask_str, // mask_str + squant, + true, // is_rotary_interleaved + 1, // num_splits + COMMON_ARGS); + CHECK_RESULT(result); +} diff --git a/test/ck_tile/gemm_multi_abd/CMakeLists.txt b/test/ck_tile/gemm_multi_abd/CMakeLists.txt index ac3b59d5d3..8f9b694a3b 100644 --- a/test/ck_tile/gemm_multi_abd/CMakeLists.txt +++ b/test/ck_tile/gemm_multi_abd/CMakeLists.txt @@ -5,8 +5,8 @@ if(CK_USE_OCP_FP8) endif() if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - add_gtest_executable(test_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) - add_gtest_executable(test_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) - target_compile_definitions(test_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - target_compile_definitions(test_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_ck_tile_gemm_multi_abd_cshuffle test_gemm_multi_abd_cshuffle.cpp) + add_gtest_executable(test_ck_tile_gemm_multi_abd_default2d test_gemm_multi_abd_default2d.cpp) + target_compile_definitions(test_ck_tile_gemm_multi_abd_cshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_definitions(test_ck_tile_gemm_multi_abd_default2d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp index 9821963458..87d6a9101c 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_cshuffle.cpp @@ -24,14 +24,16 @@ using KernelTypes = ::testing::Types< std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> + + // Currently MultiABD kernel doesn't support F8 data type + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp index b3a89aba05..f2476e803f 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -22,17 +22,17 @@ using KernelTypes = ::testing::Types< // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, - std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + + // Currently MultiABD kernel doesn't support F8 data type + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + //std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, >; // clang-format on diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc index 5aa113608f..33eb404fbe 100644 --- a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc @@ -1,105 +1,5 @@ #pragma once -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256) -{ - constexpr int M = 256; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256) -{ - constexpr int M = 512; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256) -{ - constexpr int M = 512; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256) -{ - constexpr int M = 256; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256) -{ - constexpr int M = 512; - constexpr int N = 768; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256) -{ - constexpr int M = 512; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256) -{ - constexpr int M = 256; - constexpr int N = 1280; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256) -{ - constexpr int M = 768; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256) -{ - constexpr int M = 1280; - constexpr int N = 512; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - -TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x256x256) -{ - constexpr int M = 1280; - constexpr int N = 256; - constexpr int K = 256; - constexpr int kBatch = 1; - - EXPECT_EQ(this->Run(M, N, K, kBatch), true); -} - TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512) { constexpr int M = 512; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 62f819ac1e..22d83306c3 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -116,19 +116,6 @@ class TestCkTileGemmPipeline : public ::testing::Test template void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { - // TODO: This should be parameterized in tests - // constexpr ck_tile::index_t M_Tile = 128; - // constexpr ck_tile::index_t N_Tile = 128; - // constexpr ck_tile::index_t K_Tile = 128; - - // constexpr ck_tile::index_t M_Warp = 1; - // constexpr ck_tile::index_t N_Warp = 4; - // constexpr ck_tile::index_t K_Warp = 1; - - // constexpr ck_tile::index_t M_Warp_Tile = 32; - // constexpr ck_tile::index_t N_Warp_Tile = 32; - // constexpr ck_tile::index_t K_Warp_Tile = sizeof(ADataType) == 2 ? 16 : 32; - constexpr bool kPadM = PadM; constexpr bool kPadN = PadN; constexpr bool kPadK = PadK; diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index 6893318ea2..f8c726794c 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -62,10 +62,10 @@ class TestCkTileGroupedGemm : public ::testing::Test static const ck_tile::index_t K_Warp_Tile = 16; }; - using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>; std::size_t get_workspace_size(const std::vector& gemm_descs) { - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>); } template @@ -436,8 +436,18 @@ class TestCkTileGroupedGemm : public ::testing::Test const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); - gemm_descs.push_back( - {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + gemm_descs.push_back({p_a, + p_b, + {/*ds_ptr*/}, + p_c, + kbatch, + M, + N, + K, + stride_As[i], + stride_Bs[i], + {/*stride_Ds*/}, + stride_Cs[i]}); } ck_tile::DeviceMem gemm_workspace; @@ -446,7 +456,7 @@ class TestCkTileGroupedGemm : public ::testing::Test if constexpr(Persistent) { // Generate kernel arguments - std::vector kargs; + std::vector> kargs; void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); const bool splitk = gemm_descs[0].k_batch > 1; for(const auto& arg : gemm_descs) @@ -468,7 +478,7 @@ class TestCkTileGroupedGemm : public ::testing::Test ck_tile::hip_check_error( hipMemcpyWithStream(kargs_ptr, kargs.data(), - kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>), hipMemcpyHostToDevice, stream.stream_id_)); #if CK_TILE_USE_WMMA diff --git a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt new file mode 100644 index 0000000000..20c4cbc1c3 --- /dev/null +++ b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt @@ -0,0 +1,9 @@ +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_grouped_gemm_multi_d test_grouped_gemm_multi_d.cpp) + target_compile_options(test_ck_tile_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() \ No newline at end of file diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp new file mode 100644 index 0000000000..deea2fc852 --- /dev/null +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_multi_d_util.hpp" + +using F16 = ck_tile::half_t; +using F8 = ck_tile::fp8_t; +using F32 = float; + +// Custom tuple-like structure for kernel configuration +template +struct KernelConfig +{ + using ALayoutType = ALayout_; + using BLayoutType = BLayout_; + using ELayoutType = ELayout_; + using DsLayoutType = ck_tile::tuple; + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using EDataType = EDataType_; + using DsDataType = ck_tile::tuple; + + static constexpr int M_Tile_ = M_Tile_val_; + static constexpr int N_Tile_ = N_Tile_val_; + static constexpr int K_Tile_ = K_Tile_val_; + static constexpr int M_Warp_ = M_Warp_val_; + static constexpr int N_Warp_ = N_Warp_val_; + static constexpr int K_Warp_ = K_Warp_val_; + static constexpr int M_Warp_Tile_ = M_Warp_Tile_val_; + static constexpr int N_Warp_Tile_ = N_Warp_Tile_val_; + static constexpr int K_Warp_Tile_ = K_Warp_Tile_val_; + static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_; + static constexpr auto Scheduler_ = Scheduler_val_; + static constexpr PipelineType Pipeline_ = Pipeline_val_; + static constexpr int BlockPerCu_ = 1; +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory>, // memory + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3>, // v3 + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4> // v4 + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmMultiD, KernelTypes); + +#include "test_grouped_gemm_multi_d_ut_cases.inc" diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_ut_cases.inc b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_ut_cases.inc new file mode 100644 index 0000000000..9c3a33cf59 --- /dev/null +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_ut_cases.inc @@ -0,0 +1,91 @@ +#pragma once + +TYPED_TEST(TestCkTileGroupedGemmMultiD, K256) +{ + const int group_count = 7; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Es; + std::vector stride_D0; + std::vector stride_D1; + + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 256 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Es.push_back(Ns[i]); + stride_D0.push_back(Ns[i]); + stride_D1.push_back(Ns[i]); + } + + this->Run( + Ms, Ns, Ks, stride_As, stride_Bs, stride_Es, stride_D0, stride_D1, kbatch, group_count); +} + +TYPED_TEST(TestCkTileGroupedGemmMultiD, K128) +{ + const int group_count = 5; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Es; + std::vector stride_D0; + std::vector stride_D1; + + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Es.push_back(Ns[i]); + stride_D0.push_back(Ns[i]); + stride_D1.push_back(Ns[i]); + } + + this->Run( + Ms, Ns, Ks, stride_As, stride_Bs, stride_Es, stride_D0, stride_D1, kbatch, group_count); +} + +TYPED_TEST(TestCkTileGroupedGemmMultiD, LargeMNK_8Groups) +{ + const int group_count = 8; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Es; + std::vector stride_D0; + std::vector stride_D1; + + for(int i = 0; i < group_count; i++) + { + Ms.push_back(512 + 256 * i); + Ns.push_back(512 + 256 * i); + Ks.push_back(768 + 256 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Es.push_back(Ns[i]); + stride_D0.push_back(Ns[i]); + stride_D1.push_back(Ns[i]); + } + + this->Run( + Ms, Ns, Ks, stride_As, stride_Bs, stride_Es, stride_D0, stride_D1, kbatch, group_count); +} diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp new file mode 100644 index 0000000000..4c13b4a7f7 --- /dev/null +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -0,0 +1,431 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" + +enum class PipelineType +{ + Memory = 0, + CompV3 = 1, + CompV4 = 2 +}; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +struct MultiplyMultiply +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +class TestCkTileGroupedGemmMultiD : public ::testing::Test +{ + protected: + using ALayout = typename Config::ALayoutType; + using BLayout = typename Config::BLayoutType; + using ELayout = typename Config::ELayoutType; + using DsLayout = typename Config::DsLayoutType; + using ADataType = typename Config::ADataType; + using BDataType = typename Config::BDataType; + using AccDataType = typename Config::AccDataType; + using EDataType = typename Config::EDataType; + using PrecType = BDataType; + using DsDataType = typename Config::DsDataType; + using D0DataType = std::tuple_element_t<0, DsDataType>; + using D1DataType = std::tuple_element_t<1, DsDataType>; + using D0Layout = std::tuple_element_t<0, DsLayout>; + using D1Layout = std::tuple_element_t<1, DsLayout>; + + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = false; + + static constexpr bool TransposeC = false; // transpose c is not supported + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + + auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) + { + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = std:: + conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + } + + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; + inline std::size_t get_workspace_size(const std::vector& gemm_descs) + { + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + } + + template + void invoke_grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + // for testing purposes, we can hardcode the values here as we what is compatible with + // pipeline + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * Config::K_Tile_; + const ck_tile::index_t K_split = + (gemm_descs[0].K + k_grain - 1) / k_grain * Config::K_Tile_; + const ck_tile::index_t num_loop = + ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = std::conditional_t< + Config::Pipeline_ == (PipelineType::Memory), + ck_tile::GemmPipelineAgBgCrMem, + std::conditional_t, + ck_tile::GemmPipelineAgBgCrCompV4>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } + + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + std::vector& stride_As, + std::vector& stride_Bs, + std::vector& stride_Es, + std::vector& stride_D0, + std::vector& stride_D1, + const int kbatch = 1, + const int group_count = 16) + { + + using namespace ck_tile::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> e_m_n_tensors; + std::vector> d0_m_n_tensors; + std::vector> d1_m_n_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + e_m_n_tensors.reserve(group_count); + d0_m_n_tensors.reserve(group_count); + d1_m_n_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> e_m_n_dev_buf; + std::vector> d0_m_n_dev_buf; + std::vector> d1_m_n_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + e_m_n_dev_buf.reserve(group_count); + d0_m_n_dev_buf.reserve(group_count); + d1_m_n_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + stride_As[i] = f_get_default_stride(M, K, stride_As[i], ALayout{}); + stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{}); + stride_Es[i] = f_get_default_stride(M, N, stride_Es[i], ELayout{}); + stride_D0[i] = f_get_default_stride(M, N, stride_D0[i], D0Layout{}); + stride_D1[i] = f_get_default_stride(M, N, stride_D1[i], D1Layout{}); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, K, stride_As[i], ALayout{}))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{}))); + e_m_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, N, stride_Es[i], ELayout{}))); + d0_m_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, N, stride_D0[i], D0Layout{}))); + d1_m_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, N, stride_D1[i], D1Layout{}))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc + << " e_m_n: " << e_m_n_tensors[i].mDesc + << " d0_m_n: " << d0_m_n_tensors[i].mDesc + << " d1_m_n: " << d1_m_n_tensors[i].mDesc << std::endl; + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-2.f, 2.f}(d0_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_k_n_tensors[i].get_element_space_size_in_bytes())); + e_m_n_dev_buf.push_back(std::make_unique( + e_m_n_tensors[i].get_element_space_size_in_bytes())); + d0_m_n_dev_buf.push_back(std::make_unique( + d0_m_n_tensors[i].get_element_space_size_in_bytes())); + d1_m_n_dev_buf.push_back(std::make_unique( + d1_m_n_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data()); + e_m_n_dev_buf[i]->SetZero(); + d0_m_n_dev_buf[i]->ToDevice(d0_m_n_tensors[i].data()); + d1_m_n_dev_buf[i]->ToDevice(d1_m_n_tensors[i].data()); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_e = e_m_n_dev_buf[i]->GetDeviceBuffer(); + + std::array ds_ptr_buf = { + d0_m_n_dev_buf[i]->GetDeviceBuffer(), d1_m_n_dev_buf[i]->GetDeviceBuffer()}; + std::array stridesDs = {stride_D0[i], + stride_D1[i]}; + + gemm_descs.push_back({p_a, + p_b, + ds_ptr_buf, + p_e, + kbatch, + M, + N, + K, + stride_As[i], + stride_Bs[i], + stridesDs, + stride_Es[i]}); + } + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + + invoke_grouped_gemm(gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + gemm_workspace.GetDeviceBuffer()); + + // Copy results back to host for validation + for(int i = 0; i < group_count; i++) + { + e_m_n_dev_buf[i]->FromDevice(e_m_n_tensors[i].data()); + } + + std::vector> e_m_n_host_refs; + e_m_n_host_refs.reserve(group_count); + + bool pass{true}; + for(int i = 0; i < group_count; ++i) + { + e_m_n_host_refs.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(Ms[i], Ns[i], stride_Es[i], ELayout{}))); + + e_m_n_host_refs[i].SetZero(); + + ck_tile::reference_gemm_multiple_d( + a_m_k_tensors[i], + b_k_n_tensors[i], + {d0_m_n_tensors[i], d1_m_n_tensors[i]}, + e_m_n_host_refs[i]); + const float max_accumulated_value = + *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); + + const auto rtol_atol = calculate_rtol_atol(Ks[i], 1, max_accumulated_value); + + pass &= + ck_tile::check_err(e_m_n_tensors[i], + e_m_n_host_refs[i], + "Error: Incorrect results! in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + EXPECT_TRUE(pass); + } +}; diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index 799a5f2907..d2f64920fd 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -88,10 +88,10 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } - using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>; inline std::size_t get_workspace_size(const std::vector& gemm_descs) { - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>); } template @@ -333,8 +333,18 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); - gemm_descs.push_back( - {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + gemm_descs.push_back({p_a, + p_b, + {/*ds_ptr*/}, + p_c, + kbatch, + M, + N, + K, + stride_As[i], + stride_Bs[i], + {/*stride_Ds*/}, + stride_Cs[i]}); } ck_tile::DeviceMem gemm_workspace; diff --git a/test/conv_tensor_rearrange/test_conv_tensor_rearrange_interface.cpp b/test/conv_tensor_rearrange/test_conv_tensor_rearrange_interface.cpp index df8b77aba1..36d31d53fa 100644 --- a/test/conv_tensor_rearrange/test_conv_tensor_rearrange_interface.cpp +++ b/test/conv_tensor_rearrange/test_conv_tensor_rearrange_interface.cpp @@ -188,7 +188,7 @@ TEST_F(TestConvTensorRearrangeInterface1ScalarPerVector, X1ScalarPerVector) is_supported = this->template Run(); EXPECT_TRUE(is_supported); // vector load C % ScalarPerVector, dilation - this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {1}, {2}, {0}, {0}}; + this->conv_param = {1, 1, 1, 1, 1, {4}, {8}, {1}, {2}, {0}, {0}}; is_supported = this->template Run(); EXPECT_TRUE(is_supported); is_supported = this->template Run(); @@ -234,7 +234,7 @@ TEST_F(TestConvTensorRearrangeInterface4ScalarPerVector, X4ScalarPerVector) is_supported = this->template Run(); EXPECT_FALSE(is_supported); // vector load C % ScalarPerVector, dilation - this->conv_param = {1, 1, 1, 1, 1, {4}, {3}, {1}, {2}, {0}, {0}}; + this->conv_param = {1, 1, 1, 1, 1, {4}, {8}, {1}, {2}, {0}, {0}}; is_supported = this->template Run(); EXPECT_FALSE(is_supported); is_supported = this->template Run(); @@ -250,13 +250,13 @@ TEST_F(TestConvTensorRearrangeInterface4ScalarPerVector, X4ScalarPerVector) TEST_F(TestConvTensorRearrangeInterface4ScalarPerVectorFakeC, X4ScalarPerVectorFakeC) { // C = 3 - this->conv_param = {1, 1, 1, 1, 3, {4}, {3}, {1}, {1}, {0}, {0}}; + this->conv_param = {1, 1, 1, 1, 3, {4}, {5}, {1}, {1}, {0}, {0}}; bool is_supported = this->template Run(); EXPECT_FALSE(is_supported); is_supported = this->template Run(); EXPECT_FALSE(is_supported); // C = 4 - this->conv_param = {1, 1, 1, 1, 8, {4}, {3}, {1}, {1}, {0}, {0}}; + this->conv_param = {1, 1, 1, 1, 8, {4}, {5}, {1}, {1}, {0}, {0}}; is_supported = this->template Run(); EXPECT_TRUE(is_supported); is_supported = this->template Run(); diff --git a/test/gemm_multi_abd/CMakeLists.txt b/test/gemm_multi_abd/CMakeLists.txt new file mode 100644 index 0000000000..d700414b05 --- /dev/null +++ b/test/gemm_multi_abd/CMakeLists.txt @@ -0,0 +1,9 @@ +add_gtest_executable(test_gemm_multi_abd_wmma test_gemm_multi_abd_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_multi_abd_wmma PRIVATE utility device_gemm_multi_abd_instance) +endif() + +add_gtest_executable(test_gemm_multi_abd_xdl test_gemm_multi_abd_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_multi_abd_xdl PRIVATE utility device_gemm_multi_abd_instance) +endif() diff --git a/test/gemm_multi_abd/test_gemm_common.hpp b/test/gemm_multi_abd/test_gemm_common.hpp new file mode 100644 index 0000000000..030fbcac77 --- /dev/null +++ b/test/gemm_multi_abd/test_gemm_common.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmCommon : public ::testing::Test +{ + protected: + using AsLayout = std::tuple_element_t<0, Tuple>; + using BsLayout = std::tuple_element_t<1, Tuple>; + using DsLayout = std::tuple_element_t<2, Tuple>; + using ELayout = Row; + using AsDataType = std::tuple_element_t<3, Tuple>; + using BsDataType = std::tuple_element_t<4, Tuple>; + using DsDataType = std::tuple_element_t<5, Tuple>; + using EDataType = std::tuple_element_t<6, Tuple>; + using AElementOp = std::tuple_element_t<7, Tuple>; + using BElementOp = std::tuple_element_t<8, Tuple>; + using CDEElementOp = std::tuple_element_t<9, Tuple>; + + void Run() + { + std::vector> lengths = { + {16, 32, 64}, {512, 1024, 2048}, {1024, 512, 32}}; + + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + // Assuming same layout for all A matrices (same applies for Bs and Ds) + int StrideA = ck::is_same_v>, Row> ? K : M; + int StrideB = ck::is_same_v>, Row> ? N : K; + // In case no D matrices are provided, set stride to 0 + int StrideD = 0; + if constexpr(DsDataType::Size() > 0) + { + StrideD = ck::is_same_v>, Row> ? N : M; + } + int StrideE = ck::is_same_v ? N : M; + + all_success = + all_success & ck::profiler::profile_gemm_multi_abd_impl( + 1, 2, false, false, M, N, K, StrideA, StrideB, StrideD, StrideE); + } + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp b/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp new file mode 100644 index 0000000000..a15f95bbf8 --- /dev/null +++ b/test/gemm_multi_abd/test_gemm_multi_abd_wmma.cpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_multi_abd_impl.hpp" +#include "test_gemm_common.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using I8 = int8_t; +using BF16 = ck::bhalf_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using Add = ck::tensor_operation::element_wise::Add; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; +using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; + +using KernelTypesABD = ::testing::Types< +#if 0 // TBD: skip temporary because they fail HostTensdorDescriptor validation + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>, +#endif + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>; + +TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD); +TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); } + +} // namespace test +} // namespace ck diff --git a/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp b/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp new file mode 100644 index 0000000000..a15f95bbf8 --- /dev/null +++ b/test/gemm_multi_abd/test_gemm_multi_abd_xdl.cpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_multi_abd_impl.hpp" +#include "test_gemm_common.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using I8 = int8_t; +using BF16 = ck::bhalf_t; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Multiply = ck::tensor_operation::element_wise::Multiply; +using Add = ck::tensor_operation::element_wise::Add; +using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; +using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu; +using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; + +using KernelTypesABD = ::testing::Types< +#if 0 // TBD: skip temporary because they fail HostTensdorDescriptor validation + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + Add>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + Multiply, + AddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + FastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>, + std::tuple, + ck::Tuple, + ck::Tuple<>, + ck::Tuple, + ck::Tuple, + ck::Tuple<>, + BF16, + PassThrough, + Multiply, + PassThrough>, +#endif + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAddFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyAdd>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + MultiplyFastGelu>, + std::tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Multiply>>; + +TYPED_TEST_SUITE(TestGemmCommon, KernelTypesABD); +TYPED_TEST(TestGemmCommon, Test_BF16I8BF16) { this->Run(); } + +} // namespace test +} // namespace ck diff --git a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc index f4011cf998..3a42638e30 100644 --- a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc +++ b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc @@ -2,7 +2,7 @@ TYPED_TEST(TestGroupedGemm, TinyCases) { - const std::vector Ms{0, 1}; + const std::vector Ms{2, 1}; constexpr int N = 768; constexpr int K = 544; @@ -14,7 +14,7 @@ TYPED_TEST(TestGroupedGemm, TinyCases) TYPED_TEST(TestGroupedGemm, SmallCases) { - const std::vector Ms{2, 1, 3, 4, 5, 0}; + const std::vector Ms{2, 1, 3, 4, 5}; constexpr int N = 768; constexpr int K = 544;