diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1d0f7df3c6..af36f492ba 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @shumway @vidyasagar-amd +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd # Documentation files docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD *.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD diff --git a/.github/scripts/therock_configure_ci.py b/.github/scripts/therock_configure_ci.py index 557afe2d84..cc66fdbfe8 100644 --- a/.github/scripts/therock_configure_ci.py +++ b/.github/scripts/therock_configure_ci.py @@ -42,6 +42,24 @@ def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]: file=sys.stderr, ) return None + +GITHUB_WORKFLOWS_CI_PATTERNS = [ + "therock*", +] + +def is_path_workflow_file_related_to_ci(path: str) -> bool: + return any( + fnmatch.fnmatch(path, ".github/workflows/" + pattern) + for pattern in GITHUB_WORKFLOWS_CI_PATTERNS + ) or any( + fnmatch.fnmatch(path, ".github/scripts/" + pattern) + for pattern in GITHUB_WORKFLOWS_CI_PATTERNS + ) + +def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool: + if paths is None: + return False + return any(is_path_workflow_file_related_to_ci(p) for p in paths) # Paths matching any of these patterns are considered to have no influence over # build or test workflows so any related jobs can be skipped if all paths @@ -82,12 +100,16 @@ def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool: ) other_paths = paths_set - github_workflows_paths + related_to_ci = check_for_workflow_file_related_to_ci(github_workflows_paths) contains_other_non_skippable_files = check_for_non_skippable_path(other_paths) print("should_ci_run_given_modified_paths findings:") print(f" contains_other_non_skippable_files: {contains_other_non_skippable_files}") - if contains_other_non_skippable_files: + if related_to_ci: + print("Enabling build jobs since a related workflow file was modified") + return True + elif contains_other_non_skippable_files: print("Enabling TheRock CI jobs since a non-skippable path was modified") return True else: diff --git a/.github/workflows/therock-ci-linux.yml b/.github/workflows/therock-ci-linux.yml index 7db124d2a1..271c6376ca 100644 --- a/.github/workflows/therock-ci-linux.yml +++ b/.github/workflows/therock-ci-linux.yml @@ -26,31 +26,49 @@ jobs: AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} TEATIME_FORCE_INTERACTIVE: 0 AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini + CACHE_DIR: ${{ github.workspace }}/.container-cache + # The ccache.conf will be written by setup_ccache.py before this gets used. + CCACHE_CONFIGPATH: ${{ github.workspace }}/.ccache/ccache.conf steps: + - name: "Checking out repository for rocm-libraries" + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + repository: "ROCm/rocm-libraries" + - name: Checkout composable_kernel repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + path: "composable_kernel" - name: Checkout TheRock repository uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: repository: "ROCm/TheRock" - ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57 + ref: dc05d637054ad197c84b00e24b6262af0ec797c6 # 10-03-2025 commit path: "TheRock" + - name: Setup ccache + run: | + ./TheRock/build_tools/setup_ccache.py \ + --config-preset "github-oss-presubmit" \ + --dir "$(dirname $CCACHE_CONFIGPATH)" \ + --local-path "$CACHE_DIR/ccache" + echo "namespace = ext_composable_kernel" >> $CCACHE_CONFIGPATH + echo "[*] ccache_config contents:" + cat $CCACHE_CONFIGPATH + - name: Runner Health Settings run: | - df -h - cmake --version - echo "Installed Python versions:" - ls -d /opt/python - echo "python: $(which python), python3: $(which python3)" - echo "Git version: $(git --version)" - git config --global --add safe.directory $PWD - git config fetch.parallel 10 - + ./TheRock/build_tools/health_status.py + - name: Fetch sources run: | - ./TheRock/build_tools/fetch_sources.py --jobs 12 + ./TheRock/build_tools/fetch_sources.py --jobs 12 --no-include-rocm-libraries --no-include-ml-frameworks + + - name: Patch rocm-libraries + run: | + git config --global --add safe.directory '*' + git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch - name: Install python deps run: | @@ -84,6 +102,10 @@ jobs: echo "Artifacts:" echo "----------" du -h -d 1 TheRock/build/artifacts + echo "CCache Stats:" + echo "-------------" + ccache -s -v + tail -v -n +1 .ccache/compiler_check_cache/* > TheRock/build/logs/ccache_compiler_check_cache.log - name: Configure AWS Credentials for non-forked repos if: ${{ always() && !github.event.pull_request.head.repo.fork }} @@ -92,32 +114,14 @@ jobs: aws-region: us-east-2 role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external - - name: Create Logs index Files and upload logs + - name: Post Build Upload if: always() run: | - python3 TheRock/build_tools/github_actions/create_log_index.py \ - --build-dir=TheRock/build \ - --amdgpu-family=${{ env.AMDGPU_FAMILIES }} - - python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \ - --build-dir=TheRock/build \ - --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} - - - name: Upload artifacts - run: | - python TheRock/build_tools/github_actions/upload_build_artifacts.py \ + python3 TheRock/build_tools/github_actions/post_build_upload.py \ --run-id ${{ github.run_id }} \ --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ - --build-dir TheRock/build - - - name: Add Links to Job Summary - if: always() - run: | - python TheRock/build_tools/github_actions/upload_build_summary.py \ - --run-id ${{ github.run_id }} \ - --amdgpu-family ${{ env.AMDGPU_FAMILIES }} \ - --build-dir TheRock/build + --build-dir TheRock/build \ + --upload therock-test-linux: name: "Test" diff --git a/.github/workflows/therock-ci.yml b/.github/workflows/therock-ci.yml index 3232652b6b..40a3b0bec8 100644 --- a/.github/workflows/therock-ci.yml +++ b/.github/workflows/therock-ci.yml @@ -56,7 +56,14 @@ jobs: uses: ./.github/workflows/therock-ci-linux.yml secrets: inherit with: - cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../" + cmake_options: >- + -DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON + -DTHEROCK_ENABLE_MIOPEN=ON + -DTHEROCK_ENABLE_ALL=OFF + -DTHEROCK_USE_EXTERNAL_COMPOSABLE_KERNEL=ON + -DTHEROCK_COMPOSABLE_KERNEL_SOURCE_DIR=../composable_kernel + -DTHEROCK_USE_EXTERNAL_ROCM_LIBRARIES=ON + -DTHEROCK_ROCM_LIBRARIES_SOURCE_DIR=../ amdgpu_families: "gfx94X-dcgpu" test_runs_on: "linux-mi325-1gpu-ossci-rocm" diff --git a/.github/workflows/therock-test-component.yml b/.github/workflows/therock-test-component.yml new file mode 100644 index 0000000000..068dbe3033 --- /dev/null +++ b/.github/workflows/therock-test-component.yml @@ -0,0 +1,71 @@ +name: Test component + +on: + workflow_call: + inputs: + artifact_run_id: + type: string + default: "" + amdgpu_families: + type: string + test_runs_on: + type: string + platform: + type: string + component: + type: string + + +permissions: + contents: read + +jobs: + test_component: + name: 'Test ${{ fromJSON(inputs.component).job_name }} (shard ${{ matrix.shard }} of ${{ fromJSON(inputs.component).total_shards }})' + runs-on: ${{ inputs.test_runs_on }} + container: + image: ${{ inputs.platform == 'linux' && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:4150afe4759d14822f0e3f8930e1124f26e11f68b5c7b91ec9a02b20b1ebbb98' || null }} + options: --ipc host + --group-add video + --device /dev/kfd + --device /dev/dri + --group-add 110 + --env-file /etc/podinfo/gha-gpu-isolation-settings + strategy: + fail-fast: false + matrix: + # The shard array is based on "total_shards" from "fetch_test_configurations.py" + # The test executable will shard based on the array. (ex: [1, 2, 3, 4] = four test shards) + shard: ${{ fromJSON(inputs.component).shard_arr }} + defaults: + run: + shell: bash + env: + VENV_DIR: ${{ github.workspace }}/.venv + ARTIFACT_RUN_ID: "${{ inputs.artifact_run_id != '' && inputs.artifact_run_id || github.run_id }}" + OUTPUT_ARTIFACTS_DIR: "./build" + THEROCK_BIN_DIR: "./build/bin" + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + steps: + - name: Checkout Repository + uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + repository: "ROCm/TheRock" + + - name: Run setup test environment workflow + uses: './.github/actions/setup_test_environment' + with: + ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} + AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} + OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} + VENV_DIR: ${{ env.VENV_DIR }} + FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }} + IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} + + - name: Test + timeout-minutes: ${{ fromJSON(inputs.component).timeout_minutes }} + env: + SHARD_INDEX: ${{ matrix.shard }} + TOTAL_SHARDS: ${{ fromJSON(inputs.component).total_shards }} + run: | + ${{ fromJSON(inputs.component).test_script }} diff --git a/.github/workflows/therock-test-packages.yml b/.github/workflows/therock-test-packages.yml index 37ddd399ad..54e068eb3d 100644 --- a/.github/workflows/therock-test-packages.yml +++ b/.github/workflows/therock-test-packages.yml @@ -37,41 +37,17 @@ jobs: test_components: name: 'Test ${{ matrix.components.job_name }}' - runs-on: ${{ inputs.test_runs_on }} - needs: configure_test_matrix + needs: [configure_test_matrix] # skip tests if no test matrix to run if: ${{ needs.configure_test_matrix.outputs.components != '[]' }} strategy: fail-fast: false matrix: components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }} - defaults: - run: - shell: bash - env: - VENV_DIR: ${{ github.workspace }}/.venv - ARTIFACT_RUN_ID: "${{ github.run_id }}" - OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build - THEROCK_BIN_DIR: "./build/bin" - steps: - - name: Checkout Repository - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - repository: "ROCm/TheRock" - - - name: Run setup test environment workflow - uses: './.github/actions/setup_test_environment' - with: - ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }} - AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }} - OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }} - VENV_DIR: ${{ env.VENV_DIR }} - FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }} - PLATFORM: ${{ inputs.platform }} - IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }} - - - name: Test - timeout-minutes: ${{ matrix.components.timeout_minutes }} - run: | - if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . ${VENV_DIR}/Scripts/activate ; fi - ${{ matrix.components.test_script }} + uses: './.github/workflows/therock-test-component.yml' + with: + artifact_run_id: ${{ github.run_id }} + amdgpu_families: ${{ inputs.amdgpu_families }} + test_runs_on: ${{ inputs.test_runs_on }} + platform: ${{ inputs.platform }} + component: ${{ toJSON(matrix.components) }} diff --git a/CHANGELOG.md b/CHANGELOG.md index f21795012d..9de78f3043 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,9 +5,12 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added +* Added a compute async pipeline in the CK TILE universal GEMM on gfx950 * Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. +* Added the new api to load different memory sizes to SGPR. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. +* Added support for grouped_gemm kernels to perform multi_d elementwise operation. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data * Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. * Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). @@ -31,7 +34,10 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added benchmarking support for tile engine GEMM Multi D. * Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. * Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. -* Added tensor-wise quantization for CK_TILE GEMM +* Added support for f32 to FMHA (fwd/bwd). +* Added tensor-wise quantization for CK_TILE GEMM. +* Added support for batched contraction kernel. +* Added pooling kernel in CK_TILE ### Optimized diff --git a/CMakeLists.txt b/CMakeLists.txt index 26d91fe6d8..f4d3a83c34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -220,6 +220,9 @@ rocm_check_target_ids(SUPPORTED_GPU_TARGETS message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") +# Cache SUPPORTED_GPU_TARGETS for debug +set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets") + if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") message(STATUS "Enabling XDL instances") add_definitions(-DCK_USE_XDL) @@ -339,6 +342,7 @@ endif() option(USE_BITINT_EXTENSION_INT4 "Whether to enable clang's BitInt extension to provide int4 data type." OFF) option(USE_OPT_GFX11 "Whether to enable LDS cumode and Wavefront32 mode for GFX11 silicons." OFF) option(ENABLE_ASM_DUMP "Whether to enable assembly dump for kernels." OFF) +option(ENABLE_JSON_DUMP "Whether to enable json dump for examples." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -352,6 +356,11 @@ if(ENABLE_ASM_DUMP) message("CK compiled with ENABLE_ASM_DUMP set to ${ENABLE_ASM_DUMP}") endif() +if (ENABLE_JSON_DUMP) + add_compile_definitions(CK_ENABLE_JSON_DUMP) + message("CK compiled with ENABLE_JSON_DUMP set to ${ENABLE_JSON_DUMP}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/Dockerfile b/Dockerfile index 6f5cd0115d..07327442fe 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,27 +1,23 @@ + FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=6.4.1 +ARG ROCMVERSION=7.0.1 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn +ENV DEBIAN_FRONTEND=noninteractive # Add rocm repository RUN set -xe && \ - apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ - curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg + apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl -RUN if [ "$ROCMVERSION" != "6.5" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60401-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60401-1_all.deb && \ - wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \ - fi - -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu jammy main universe | tee -a /etc/apt/sources.list" && \ - amdgpu-install -y --usecase=rocm --no-dkms +RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \ + apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \ + apt update && \ + apt install python3-setuptools python3-wheel -y && \ + apt install rocm-dev -y ## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache @@ -45,7 +41,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libelf-dev \ libnuma-dev \ libpthread-stubs0-dev \ - llvm-amdgpu \ mpich \ net-tools \ pkg-config \ @@ -61,17 +56,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- zip \ libzstd-dev \ openssh-server \ - clang-format-12 \ clang-format-18 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ rm -rf amdgpu-install* && \ -# Remove unnecessary rocm components that take a lot of space - apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt - #Install latest ccache -RUN git clone https://github.com/ccache/ccache.git && \ + git clone https://github.com/ccache/ccache.git && \ cd ccache && mkdir build && cd build && cmake .. && make install && \ #Install ninja build tracing tools cd / && \ diff --git a/Dockerfile.compiler b/Dockerfile.compiler index 0306057e45..47bd8294b6 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4.1" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index efe08a7d41..3fbcdb5849 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -46,6 +46,58 @@ def runShell(String command){ return (output != "") } +def shouldRunCICheck() { + // Define patterns for files that should not trigger CI + def skipFilePatterns = [ + /^\.github\/.*/, // GitHub workflow files + /^docs\/.*/, // Documentation files + /^LICENSE$/, // License file + /^.*\.gitignore$/, // Git ignore files + /.*\.md$/ // Markdown files + ] + + try { + // Get the list of changed files + def changedFiles = sh( + returnStdout: true, + script: ''' + if [ "$CHANGE_ID" != "" ]; then + # For PR builds, compare against target branch + git diff --name-only origin/$CHANGE_TARGET...HEAD + else + # For regular builds, compare against previous commit + git diff --name-only HEAD~1..HEAD + fi + ''' + ).trim().split('\n') + + if (changedFiles.isEmpty() || (changedFiles.size() == 1 && changedFiles[0].trim().isEmpty())) { + echo "No changed files detected - this might be a manual trigger or merge commit, running CI for safety" + return true + } + + echo "Changed files: ${changedFiles.join(', ')}" + + // Check if any changed files are not in the skip patterns + def hasFilesRequiringCI = changedFiles.any { file -> + !skipFilePatterns.any { pattern -> + file ==~ pattern + } + } + + if (hasFilesRequiringCI) { + echo "Found files that require CI" + return true + } else { + echo "Only non-relevant files changed, skipping CI" + return false + } + } catch (Exception e) { + echo "Error checking changed files: ${e.getMessage()}, running CI by default" + return true + } +} + def getBaseDockerImageName(){ def img if (params.USE_CUSTOM_DOCKER != ""){ @@ -53,7 +105,7 @@ def getBaseDockerImageName(){ } else{ def ROCM_numeric = parseVersion("${params.ROCMVERSION}") - if ( ROCM_numeric.major <= 6 && ROCM_numeric.minor < 5 ){ + if ( ROCM_numeric.major <= 7 && ROCM_numeric.minor < 1 ){ img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ @@ -476,7 +528,7 @@ def buildHipClangJob(Map conf=[:]){ def retimage (retimage, image) = getDockerImage(conf) - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 20, unit: 'HOURS') { @@ -538,7 +590,7 @@ def Build_CK(Map conf=[:]){ def image def retimage - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { @@ -728,7 +780,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -836,7 +888,7 @@ def run_aiter_tests(Map conf=[:]){ dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " echo "Docker flags: ${dockerOpts}" - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -852,13 +904,14 @@ def run_aiter_tests(Map conf=[:]){ } withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 2, unit: 'HOURS'){ + timeout(time: 5, unit: 'HOURS'){ try{ sh "rocminfo" sh "python3 --version" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py" + sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha_varlen.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py" sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py" @@ -894,7 +947,7 @@ def run_pytorch_tests(Map conf=[:]){ dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " echo "Docker flags: ${dockerOpts}" - gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { try { echo "Pulling image: ${image}" @@ -930,13 +983,14 @@ def run_pytorch_tests(Map conf=[:]){ } //launch develop branch daily jobs -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true - 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true - 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false - 0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX10=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false''' : "" +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_PERFORMANCE_TESTS=true;FORCE_CI=true + 0 22 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true;FORCE_CI=true + 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;FORCE_CI=true + 0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;FORCE_CI=true + 0 11 * * * % RUN_PYTORCH_TESTS=true;RUN_CODEGEN_TESTS=false;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false;BUILD_GFX10=false;BUILD_GFX11=false;BUILD_GFX12=false;BUILD_GFX90A=false;FORCE_CI=true''' : "" pipeline { agent none @@ -957,8 +1011,8 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '6.4.1', - description: 'Specify which ROCM version to use: 6.4.1 (default).') + defaultValue: '7.0.1', + description: 'Specify which ROCM version to use: 7.0.1 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', @@ -985,8 +1039,8 @@ pipeline { description: "Use the CK build to verify hipTensor build and tests (default: OFF)") string( name: 'hipTensor_branch', - defaultValue: 'mainline', - description: 'Specify which branch of hipTensor to use (default: mainline)') + defaultValue: 'develop', + description: 'Specify which branch of hipTensor to use (default: develop)') booleanParam( name: "USE_SCCACHE", defaultValue: true, @@ -1037,8 +1091,8 @@ pipeline { description: "Build CK and run tests on gfx942 (default: ON)") booleanParam( name: "BUILD_GFX950", - defaultValue: false, - description: "Build CK and run tests on gfx950 (default: OFF)") + defaultValue: true, + description: "Build CK and run tests on gfx950 (default: ON)") booleanParam( name: "BUILD_GFX10", defaultValue: true, @@ -1091,6 +1145,10 @@ pipeline { name: 'ck_aiter_branch', defaultValue: 'develop', description: 'Specify which branch of CK to test with AITER (default: develop)') + booleanParam( + name: "FORCE_CI", + defaultValue: false, + description: "Force CI to run even when only non-relevant files are changed (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -1104,7 +1162,20 @@ pipeline { DOCKER_BUILDKIT = "1" } stages{ + stage("Determine CI Execution") { + agent{ label rocmnode("nogpu") } + steps { + script { + env.SHOULD_RUN_CI = String.valueOf(params.FORCE_CI.toBoolean() || shouldRunCICheck()) + echo "SHOULD_RUN_CI: ${env.SHOULD_RUN_CI}" + } + } + } stage("Build Docker"){ + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel{ stage('Docker /opt/rocm'){ agent{ label rocmnode("nogpu") } @@ -1116,6 +1187,10 @@ pipeline { } } stage("Static checks") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel{ stage('Clang Format and Cppcheck') { when { @@ -1125,16 +1200,16 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ - -o -not -path \'*.git*\' -iname \'*.hpp\' \ - -o -not -path \'*.git*\' -iname \'*.cpp\' \ - -o -iname \'*.h.in\' \ - -o -iname \'*.hpp.in\' \ - -o -iname \'*.cpp.in\' \ - -o -iname \'*.cl\' \ + execute_cmd = "(cd .. && git ls-files \'*.h\' \ + \'*.hpp\' \ + \'*.cpp\' \ + \'*.h.in\' \ + \'*.hpp.in\' \ + \'*.cpp.in\' \ + \'*.cl\' \ | grep -v 'build/' \ | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\') && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ @@ -1155,16 +1230,17 @@ pipeline { agent{ label rocmnode("nogpu") } environment{ setup_args = "NO_CK_BUILD" - execute_cmd = "find .. -not -path \'*.git*\' -iname \'*.h\' \ - -o -not -path \'*.git*\' -iname \'*.hpp\' \ - -o -not -path \'*.git*\' -iname \'*.cpp\' \ - -o -iname \'*.h.in\' \ - -o -iname \'*.hpp.in\' \ - -o -iname \'*.cpp.in\' \ - -o -iname \'*.cl\' \ + execute_cmd = "(cd .. && git ls-files \ + \'*.h\' \ + \'*.hpp\' \ + \'*.cpp\' \ + \'*.h.in\' \ + \'*.hpp.in\' \ + \'*.cpp.in\' \ + \'*.cl\' \ | grep -v 'build/' \ | grep -v 'include/rapidjson' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\')" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) @@ -1175,6 +1251,10 @@ pipeline { } stage("Run Pytorch Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run Pytorch Tests on gfx942") @@ -1193,6 +1273,10 @@ pipeline { } stage("Run AITER Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run AITER Tests on gfx942") @@ -1223,6 +1307,10 @@ pipeline { } stage("Run Grouped Conv Large Case Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run Grouped Conv Large Case Tests on gfx90a") @@ -1247,6 +1335,10 @@ pipeline { } stage("Run Comprehensive Convolution Dataset Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run Comprehensive Dataset Tests on gfx90a") @@ -1279,6 +1371,10 @@ pipeline { } stage("Run Codegen Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run Codegen Tests on gfx90a") @@ -1290,7 +1386,7 @@ pipeline { agent{ label rocmnode("gfx90a")} environment{ setup_args = "NO_CK_BUILD" - execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake ../codegen && \ + execute_args = """ CXX=/opt/rocm/llvm/bin/clang++ cmake -DCMAKE_PREFIX_PATH=/opt/rocm ../codegen && \ make -j64 check""" } steps{ @@ -1302,6 +1398,10 @@ pipeline { } stage("Run CK_TILE_FMHA Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run CK_TILE_FMHA Tests on gfx90a") @@ -1350,7 +1450,6 @@ pipeline { } agent{ label rocmnode("gfx950") } environment{ - def docker_name = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0" setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx950 && \ make -j128 tile_example_fmha_fwd tile_example_fmha_bwd && \ @@ -1358,7 +1457,7 @@ pipeline { example/ck_tile/01_fmha/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx950 """ } steps{ - buildHipClangJobAndReboot(setup_args:setup_args, docker_name: docker_name, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) cleanWs() } } @@ -1366,6 +1465,10 @@ pipeline { } stage("Run TILE_ENGINE_GEMM Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Run TILE_ENGINE_GEMM Tests on gfx90a") @@ -1483,6 +1586,10 @@ pipeline { stage("Build CK and run Tests") { + when { + beforeAgent true + expression { env.SHOULD_RUN_CI.toBoolean() } + } parallel { stage("Build CK with RHEL8") @@ -1566,7 +1673,7 @@ pipeline { -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } @@ -1631,7 +1738,7 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ - buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0") + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB}:ck_ub24.04_rocm7.0.1") } cleanWs() } @@ -1657,13 +1764,13 @@ pipeline { cleanWs() } } - stage("Build CK and run Tests on gfx1101") + stage("Build CK and run Tests on gfx11") { when { beforeAgent true expression { params.BUILD_GFX11.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_INSTANCES_ONLY.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx1101") } + agent{ label 'miopen && (gfx1101 || gfx1100)' } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx11-generic" -DUSE_OPT_GFX11=ON -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ @@ -1700,6 +1807,17 @@ pipeline { } } } + post { + success { + script { + // Report the parent stage build ck and run tests status + def variant = env.STAGE_NAME + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { + echo "Reporting success status for build ck and run tests" + } + } + } + } } stage("Process Performance Test Results") { @@ -1717,6 +1835,22 @@ pipeline { } } } + post { + success { + script { + // Report the skipped parent's stage status + def parentVariant = "Process Performance Test Results" + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${parentVariant}", account: 'ROCm', repo: 'composable_kernel') { + echo "Process Performance Test Results stage skipped." + } + // Report the skipped stage's status + def variant = "Process results" + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "${variant}", account: 'ROCm', repo: 'composable_kernel') { + echo "Process Performance Test Results stage skipped." + } + } + } + } } } } diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 6587f4c4be..41e2fa2cc0 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -12,6 +12,17 @@ FetchContent_Declare( GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 ) +FetchContent_Populate(GTest) + +# Patch googlemock/CMakeLists.txt to fix invalid include path +set(GMOCK_CMAKE "${gtest_SOURCE_DIR}/googlemock/CMakeLists.txt") +file(READ "${GMOCK_CMAKE}" GMOCK_CMAKE_CONTENT) +string(REPLACE [[gtest_SOURCE_DIR}/include]] + [[gtest_SOURCE_DIR}/googletest/include]] + GMOCK_CMAKE_CONTENT + "${GMOCK_CMAKE_CONTENT}") +file(WRITE "${GMOCK_CMAKE}" "${GMOCK_CMAKE_CONTENT}") + # Suppress ROCMChecks WARNING on GoogleTests set(ROCM_DISABLE_CHECKS FALSE) macro(rocm_check_toolchain_var var access value list_file) @@ -24,7 +35,7 @@ if(WIN32) set(gtest_force_shared_crt ON CACHE_INTERNAL "") endif() -set(BUILD_GMOCK OFF CACHE INTERNAL "") +set(BUILD_GMOCK ON CACHE INTERNAL "") set(INSTALL_GTEST OFF CACHE INTERNAL "") # Store the current value of BUILD_SHARED_LIBS @@ -32,15 +43,12 @@ set(__build_shared_libs ${BUILD_SHARED_LIBS}) set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") set(ROCM_DISABLE_CHECKS TRUE) -FetchContent_MakeAvailable(GTest) +add_subdirectory(${gtest_SOURCE_DIR} ${gtest_BINARY_DIR}) set(ROCM_DISABLE_CHECKS FALSE) # Restore the old value of BUILD_SHARED_LIBS set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries to build" FORCE) -set(BUILD_GMOCK OFF CACHE INTERNAL "") -set(INSTALL_GTEST OFF CACHE INTERNAL "") - set(GTEST_CXX_FLAGS -Wno-undef -Wno-reserved-identifier @@ -71,3 +79,12 @@ target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0) target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0) +if(TARGET gmock) + target_compile_options(gmock PRIVATE ${GTEST_CXX_FLAGS}) + target_compile_definitions(gmock PRIVATE GTEST_HAS_SEH=0) +endif() + +if(TARGET gmock_main) + target_compile_options(gmock_main PRIVATE ${GTEST_CXX_FLAGS}) + target_compile_definitions(gmock_main PRIVATE GTEST_HAS_SEH=0) +endif() diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 2b2e6e2949..80429a781b 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -12,6 +12,7 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h) find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) +find_package(hiprtc REQUIRED) rocm_setup_version(VERSION 1.0) @@ -27,7 +28,7 @@ add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) -target_link_libraries(ck_host PRIVATE ck_headers) +target_link_libraries(ck_host PRIVATE ck_headers hiprtc::hiprtc) set_target_properties(ck_host PROPERTIES LINKER_LANGUAGE CXX diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 6cfff30dbd..0440e8246b 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -27,7 +27,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 414683ffdf..66a0d98238 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -37,7 +37,7 @@ using DeviceGemmInstance1 = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffl // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 8, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 2, S<1, 16, 1, 16>, 4, ck::LoopScheduler::Interwave, ck::PipelineVersion::v1>; // clang-format on using DeviceGemmInstance = DeviceGemmInstance1; diff --git a/example/01_gemm/gemm_xdl_fp16_fp8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp index a996d034e6..2b04d14a11 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -30,7 +30,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| | // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeType>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v1>; // clang-format on diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index 0c51a58037..af9b7978f5 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -31,7 +31,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // this instance has been tested working on gfx950 // < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on @@ -55,4 +55,12 @@ using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index d149fd88f1..d5c42558c4 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -36,7 +36,7 @@ using BDataType = ck::half_t; using CDataType = ck::half_t; using AccDataType = float; #else - < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>; + < F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 16, 128, 4, 4, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 4, 4, 7, 1>; using ADataType = float; using BDataType = float; using CDataType = float; @@ -185,7 +185,6 @@ int main(int argc, char* argv[]) auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; - // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); @@ -209,8 +208,7 @@ int main(int argc, char* argv[]) return 0; } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; diff --git a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp index d8672f6a0c..76a30657f0 100644 --- a/example/01_gemm/gemm_xdl_wavelet_fp16.cpp +++ b/example/01_gemm/gemm_xdl_wavelet_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -29,7 +29,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_WaveletM // ######| | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, F16, CDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1,8>, 8>; + < ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, F16, CDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1,8>, 4>; // clang-format on using DeviceGemmInstance = DeviceGemmInstance; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 08e2b8c15f..7fb0c1e812 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -2,7 +2,6 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck/library/utility/validation_common.hpp" // use macro to minimize code change #ifndef EXAMPLE_WITH_COMPUTE_DATATYPE @@ -29,11 +28,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if constexpr(std::is_same_v) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; @@ -59,17 +58,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); - try - { - ck::utils::validate_gemm_strides_abc( - M, N, K, StrideA, StrideB, StrideC); - } - catch(const std::runtime_error& e) - { - std::cerr << "Error: " << e.what() << std::endl; - return false; - } - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); diff --git a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp index bffa2e5640..992e7c19c8 100644 --- a/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp +++ b/example/03_gemm_bias_relu/gemm_bias_relu_xdl_fp16.cpp @@ -174,6 +174,9 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + const auto StrideD = std::is_same::value + ? d_m_n.mDesc.GetStrides()[0] + : d_m_n.mDesc.GetStrides()[1]; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; @@ -221,7 +224,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{0}, + std::array{static_cast(StrideD)}, StrideE, a_element_op, b_element_op, diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp index e630f67837..4e98bf3034 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -32,7 +32,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm int { + if constexpr(std::is_same_v) + { + return static_cast(tensor.GetStrides()[0]); + } + else + { + return static_cast(tensor.GetStrides()[1]); + } + }; + + if(StrideA <= 0) + StrideA = fetch_leading_stride(a_m_k, ALayout{}); + if(StrideB <= 0) + StrideB = fetch_leading_stride(b_k_n, BLayout{}); + if(StrideD0 <= 0) + StrideD0 = fetch_leading_stride(d0_m_n, D0Layout{}); + if(StrideD1 <= 0) + StrideD1 = fetch_leading_stride(d1_m_n, D1Layout{}); + if(StrideE <= 0) + StrideE = fetch_leading_stride(e_m_n_host_result, ELayout{}); + switch(config.init_method) { case 0: break; diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 036f288d0a..7142521c55 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -125,7 +125,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); problem_size = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index c4e7068499..4b290d02a2 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -23,7 +23,7 @@ using RsGlobalReduceOp = static constexpr auto ConvSpec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off template diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp index eb8b5c76d3..9e125c4e5d 100644 --- a/example/12_reduce/reduce_blockwise_two_call.cpp +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -100,13 +100,13 @@ int main(int argc, char* argv[]) const std::array reduceDims = {3, 4}; // const std::array invariantDims = {0, 1, 2}; - const std::vector inLengths_1 = {64, 320, 80, 4, 128}; + std::vector inLengths_1 = {64, 320, 80, 4, 128}; // input lengths of the second reduction, which is also the output lengths of the first // reduction - const std::vector inLengths_2 = {64, 320, 80, 4}; + std::vector inLengths_2 = {64, 320, 80, 4}; - const std::vector outLengths = {64, 320, 80}; + std::vector outLengths = {64, 320, 80}; if(argc == 1) { @@ -114,11 +114,26 @@ int main(int argc, char* argv[]) init_method = 2; time_kernel = true; } - else if(argc == 4) + else if((argc == 4) || (argc == 9)) { do_verify = static_cast(argv[1]); init_method = atoi(argv[2]); time_kernel = static_cast(atoi(argv[3])); + if(argc == 9) + { + inLengths_1[0] = atoi(argv[4]); + inLengths_1[1] = atoi(argv[5]); + inLengths_1[2] = atoi(argv[6]); + inLengths_1[3] = atoi(argv[7]); + inLengths_1[4] = atoi(argv[8]); + inLengths_2[0] = inLengths_1[0]; + inLengths_2[1] = inLengths_1[1]; + inLengths_2[2] = inLengths_1[2]; + inLengths_2[3] = inLengths_1[3]; + outLengths[0] = inLengths_1[0]; + outLengths[1] = inLengths_1[1]; + outLengths[2] = inLengths_1[2]; + } } else { diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp index 3ce08fd2af..abbf1b29f7 100644 --- a/example/13_pool2d_fwd/pool2d_fwd_common.hpp +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -78,12 +78,12 @@ bool pool_test(bool do_verification, if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}); + return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout); } }; diff --git a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp index 2585072dfe..5291f5ce69 100644 --- a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp +++ b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp @@ -115,12 +115,14 @@ int main() if(std::is_same::value) { return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1_uz})); + std::vector({stride, 1_uz}), + layout); } else { return HostTensorDescriptor(std::vector({row, col}), - std::vector({1_uz, stride})); + std::vector({1_uz, stride}), + layout); } }; diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp index 680cee1f81..ac64a468a4 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 90a12bc1dd..85ea8c2f2c 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index 28b0fcd0ce..fb047ae364 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp index 0c96ef56d3..6b8de05f73 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -51,7 +51,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 13da444051..4a701e7792 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -137,11 +137,13 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {row * stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {row * stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {col * stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {col * stride, 1_uz, stride}, layout); } }; diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp index 80b9724930..b19aff2081 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -44,10 +44,10 @@ using DeviceConvBwdWeightInstance = 128, // NPerBlock 4, // K0PerBlock 8, // K1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXdl + 16, // NPerXdl + 4, // MXdlPerWave + 4, // NXdlPerWave S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder @@ -80,6 +80,11 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe int main(int argc, char* argv[]) { + if(ck::is_gfx11_supported()) + { + return 0; + } + ExecutionConfig config; ck::utils::conv::ConvParam conv_param = DefaultConvParam; diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index ce9f9b7032..ae5e3f36ad 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -65,7 +65,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern //######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| //######| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<8, 32>, 8>; + < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 4, S<8, 32>, 4>; // clang-format on auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { @@ -154,8 +154,8 @@ void host_gemm_layernorm(Tensor& h_m_n, int main() { - // temp disable on gfx11 & gfx12 - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + // temp disable on gfx11 + if(ck::is_gfx11_supported()) { return 0; } diff --git a/example/22_cgemm/cgemm_xdl_bf16.cpp b/example/22_cgemm/cgemm_xdl_bf16.cpp index fa4482a984..716d36b487 100644 --- a/example/22_cgemm/cgemm_xdl_bf16.cpp +++ b/example/22_cgemm/cgemm_xdl_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -48,10 +48,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 32, // index_t KPerBlock 8, // index_t AK1 8, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -69,7 +69,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp index 89a581e865..2996d87b28 100644 --- a/example/22_cgemm/cgemm_xdl_fp16.cpp +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -47,10 +47,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 32, // index_t KPerBlock 8, // index_t AK1 8, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -68,7 +68,7 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) diff --git a/example/22_cgemm/cgemm_xdl_fp32.cpp b/example/22_cgemm/cgemm_xdl_fp32.cpp index cf96599599..45d23b4377 100644 --- a/example/22_cgemm/cgemm_xdl_fp32.cpp +++ b/example/22_cgemm/cgemm_xdl_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include @@ -48,10 +48,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 16, // index_t KPerBlock 4, // index_t AK1 4, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -69,11 +69,16 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle S<1, 16, 1, 16>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + 2>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + bool do_verification = true; int init_method = 1; bool time_kernel = false; @@ -87,25 +92,25 @@ int main(int argc, char* argv[]) ck::index_t StrideB = 4096; ck::index_t StrideC = 4096; - if(argc == 4) + if(argc == 1) { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); + // use default case } - else if(argc == 10) + else if(argc == 4 || argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); + if(argc == 10) + { + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } } else { @@ -114,7 +119,7 @@ int main(int argc, char* argv[]) << "arg3: run kernel # of times (>1)\n" << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" << std::endl; - exit(0); + exit(1); } return !run_cgemm_xdl @@ -48,10 +48,10 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 64, // index_t KPerBlock 16, // index_t AK1 16, // index_t BK1 - 32, // index_t MPerXDL - 32, // index_t NPerXDL - 4, // index_t MXdlPerWave - 2, // index_t NXdlPerWave + 16, // index_t MPerXDL + 16, // index_t NPerXDL + 8, // index_t MXdlPerWave + 4, // index_t NXdlPerWave S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder @@ -68,8 +68,8 @@ using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_ 1, // index_t BBlockLdsExtraN 1, // index_t CShuffleMXdlPerWavePerShuffle 1, // index_t CShuffleNXdlPerWavePerShuffle - S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on int main(int argc, char* argv[]) @@ -87,25 +87,25 @@ int main(int argc, char* argv[]) ck::index_t StrideB = 4096; ck::index_t StrideC = 4096; - if(argc == 4) + if(argc == 1) { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); + // use default case } - else if(argc == 10) + else if(argc == 4 || argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); + if(argc == 10) + { + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } } else { @@ -114,7 +114,7 @@ int main(int argc, char* argv[]) << "arg3: run kernel # of times (>1)\n" << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n" << std::endl; - exit(0); + exit(1); } return !run_cgemm_xdl #include #include @@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp index 548500518f..bd4be91103 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp @@ -68,10 +68,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 32, // KPerBlock 8, // AK1 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -89,11 +89,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - S<8>, // CDEShuffleBlockTransferScalarPerVectors + S<4>, // CDEShuffleBlockTransferScalarPerVectors ck::BlockGemmPipelineScheduler::Intrawave, // BlockGemmPipelineScheduler ck::BlockGemmPipelineVersion::v3 // BlockGemmPipelineVersion >; #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp index d1985f9af5..80f6b1d663 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -51,9 +53,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp index 42171bcdb7..46e452005d 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp16int4_b_scale_v3.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp index a92a04dbe6..19a63ff841 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp32.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -50,9 +52,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) +{ + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 0; + } + + return run_batched_gemm_example(argc, argv); +} diff --git a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp index 84f92eba8e..2a4610bf90 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp @@ -74,10 +74,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 64, // KPerBlock 16, // AK1 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 4, // NXdlPerWave S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder @@ -95,7 +95,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - S<8, 8, 1>, // CDEShuffleBlockTransferScalarPerVectors + S<4, 4, 1>, // CDEShuffleBlockTransferScalarPerVectors ck::BlockGemmPipelineScheduler::Interwave, // BlockGemmPipelineScheduler ck::BlockGemmPipelineVersion::v1, // BlockGemmPipelineVersion F8 // ComputeTypeA @@ -103,4 +103,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD #include "run_batched_gemm_example_rowwise.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_rowwise_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_rowwise_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_int4.cpp b/example/24_batched_gemm/batched_gemm_xdl_int4.cpp index 5e82cfe324..4adb79ed29 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_int4.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_int4.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -96,4 +98,4 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD #define BUILD_INT4_EXAMPLE #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/batched_gemm_xdl_int8.cpp b/example/24_batched_gemm/batched_gemm_xdl_int8.cpp index ad22227af5..06b8c3c0cd 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_int8.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_int8.cpp @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include #include @@ -48,9 +50,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on #include "run_batched_gemm_example.inc" -int main(int argc, char* argv[]) { return !run_batched_gemm_example(argc, argv); } +int main(int argc, char* argv[]) { return run_batched_gemm_example(argc, argv); } diff --git a/example/24_batched_gemm/run_batched_gemm_example.inc b/example/24_batched_gemm/run_batched_gemm_example.inc index 741512bf00..182ab8d967 100644 --- a/example/24_batched_gemm/run_batched_gemm_example.inc +++ b/example/24_batched_gemm/run_batched_gemm_example.inc @@ -1,3 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #pragma once @@ -59,11 +61,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(std::is_same::value) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; @@ -214,35 +218,37 @@ bool run_batched_gemm_example(int argc, char* argv[]) problem_size.batch_count = 2; - if(argc == 4) + if(argc == 1) + { + // use default case + } + else if(argc == 4 || argc == 8) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); - } - else if(argc == 8) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - problem_size.M = std::stoi(argv[4]); - problem_size.N = std::stoi(argv[5]); - problem_size.K = std::stoi(argv[6]); - problem_size.batch_count = std::stoi(argv[7]); + if(argc == 8) + { + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + problem_size.batch_count = std::stoi(argv[7]); + } } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("optinal\n"); - printf("arg4-7: M = %d N = %d K = %d Batch = %d\n", - problem_size.M, - problem_size.N, - problem_size.K, - problem_size.batch_count); - exit(0); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("optional\n"); + printf("arg4-7: M, N, K, Batch\n"); + exit(1); } + printf("M = %d N = %d K = %d Batch = %d\n", + problem_size.M, + problem_size.N, + problem_size.K, + problem_size.batch_count); problem_size.stride_A = problem_size.K; problem_size.stride_B = problem_size.K; diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 3582bc5e33..5e56670fcf 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -137,11 +137,13 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co auto layout) { if constexpr(std::is_same_v) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; @@ -344,7 +346,7 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; - return true; + return false; } bool pass = true; @@ -521,6 +523,11 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return 1; + } + ProblemSize problem_size; ExecutionConfig config; @@ -533,30 +540,30 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) problem_size.batch_count = 2; - if(argc == 4) + if(argc == 1) { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); + // use default case } - else if(argc >= 7) + else if(argc == 4 || argc >= 7) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); - - problem_size.M = std::stoi(argv[4]); - problem_size.N = std::stoi(argv[5]); - problem_size.K = std::stoi(argv[6]); - - if(argc >= 8) + if(argc >= 7) { - problem_size.batch_count = std::stoi(argv[7]); - } + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); - if(argc >= 9) - { - problem_size.KBatch = std::stoi(argv[8]); + if(argc >= 8) + { + problem_size.batch_count = std::stoi(argv[7]); + } + + if(argc >= 9) + { + problem_size.KBatch = std::stoi(argv[8]); + } } } else @@ -564,7 +571,10 @@ bool run_batched_gemm_fp16_int4_b_scale_example(int argc, char* argv[]) printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); - exit(0); + printf("arg4-6: problem size (M, N, K)\n"); + printf("arg7: batch count\n"); + printf("arg8: KBatch\n"); + exit(1); } problem_size.stride_A = problem_size.K; diff --git a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc index 778be8ffd7..6ed0b23407 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_rowwise.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #pragma once @@ -64,11 +64,13 @@ bool run_batched_gemm_rowwise(const ProblemSize& problem_size, const ExecutionCo if(std::is_same::value) { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count_, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count_, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp index 420a7cf74f..4f4003809b 100644 --- a/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp +++ b/example/25_gemm_bias_e_permute/gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp @@ -19,6 +19,9 @@ #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -342,7 +345,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_gs_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G1_M2_N3_K1 using S = ck::Sequence; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -342,7 +345,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_gs_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G1_M3_N2_K1 a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; @@ -189,7 +191,7 @@ int run_contraction_bilinear_example(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; @@ -173,7 +175,7 @@ int run_contraction_scale_example(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 #include @@ -18,6 +18,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -304,10 +307,10 @@ int main(int argc, char* argv[]) const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides, Row{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Bypass{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); ck::index_t M_ = ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); @@ -416,9 +419,9 @@ int main(int argc, char* argv[]) const auto e_ms_ns_lengths = contraction_descs[i].e_ms_ns_lengths; const auto e_ms_ns_strides = contraction_descs[i].e_ms_ns_strides; - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); e_tensors_device[i]->FromDevice(e_device_tensors[i].mData.data()); diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index f556be887f..c4cb7a13a2 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -17,6 +17,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -300,11 +303,11 @@ int main(int argc, char* argv[]) std::vector e_gs_ms_ns_strides{ G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; @@ -396,7 +399,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 #include @@ -17,6 +17,9 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -54,7 +57,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; + DeviceBatchedContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>; // clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -247,11 +250,11 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); - Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); - Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides, Row{}); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -345,7 +348,8 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor c_ms_ns_host_result( + e_gs_ms_ns_lengths, e_gs_ms_ns_strides, Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + #include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc" int main(int argc, char* argv[]) diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc index 255a0cddaf..7a03e9cacf 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_example.inc @@ -110,11 +110,13 @@ bool run_batched_gemm_gemm_example(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc index 8ab47c2925..cea18459f4 100644 --- a/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc +++ b/example/31_batched_gemm_gemm/run_batched_gemm_gemm_wmma_cshuffle_v3.inc @@ -62,17 +62,19 @@ int run(int argc, char* argv[]) std::vector b1_g_o_n_lengths{G, O, N}; #ifdef CK_MHA_USE_RCCR_LAYOUT std::vector b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N] + auto b1_layout = Row{}; #else std::vector b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O] + auto b1_layout = Col{}; #endif std::vector c_g_m_o_lengths{G, M, O}; std::vector c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O] - Tensor a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides); - Tensor b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides); - Tensor b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides); - Tensor c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides); - Tensor c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides); + Tensor a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides, Row{}); + Tensor b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides, Row{}); + Tensor b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides, b1_layout); + Tensor c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides, Row{}); + Tensor c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides, Row{}); std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp index 5794924294..7738a6b6d4 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. /* Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o @@ -100,11 +100,11 @@ using DeviceGemmInstance = 8, // AK1 8, // BK1 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave + 16, // MPerXDL + 16, // NPerXDL + 2, // MXdlPerWave + 8, // NXdlPerWave + 4, // Gemm1NXdlPerWave S<4, 64, 1>, // ABlockTransfer S<1, 0, 2>, S<1, 0, 2>, @@ -129,7 +129,7 @@ using DeviceGemmInstance = 1, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + 4, // CShuffleBlockTransferScalarPerVector_NPerBlock MaskingSpec>; // MaskingSpecialization // Ref Gemm0: fp16 in, fp32 out diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc index 1514fc48b3..aa2a6b3b42 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm.inc @@ -111,12 +111,14 @@ int run(int argc, char* argv[]) if(std::is_same::value) { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, stride, 1})); + std::vector({batch_stride, stride, 1}), + layout); } else { return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, 1, stride})); + std::vector({batch_stride, 1, stride}), + layout); } }; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc index 2b02069e65..6175f0b5be 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -88,11 +90,11 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Bypass{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Bypass{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Bypass{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc index e0ccb6dad1..db13e3b963 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -88,11 +92,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc index 0ad031cc71..1e4b52d4cf 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -113,11 +117,30 @@ int run(int argc, char* argv[]) head_dim, 1}; // C layout [batch_size, head_num, q_sequence_length, head_dim] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; @@ -191,7 +214,7 @@ int run(int argc, char* argv[]) head_num * 2 * head_dim, head_dim, 1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim] - Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides); + Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides, Bypass{}); // merge kv into a packed pointer send to device b0_gs_ns_ks.ForEach( [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index c693995140..874d987a1d 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -63,6 +67,19 @@ int run(int argc, char* argv[]) std::size_t flop = 0, num_byte = 0; + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; std::cout << "group count " << group_count << ". printing first 4 groups\n"; for(std::size_t i = 0; i < group_count; i++) { @@ -113,10 +130,14 @@ int run(int argc, char* argv[]) {}}); // acc1_biases_gs_ms_os_strides // C_m_o = A_m_k * B0_k_n * B1_n_o - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks(f_host_tensor_descriptor( + b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns(f_host_tensor_descriptor( + b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_device_result(f_host_tensor_descriptor( + c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); int Batch = G0 * G1; flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch; @@ -252,7 +273,8 @@ int run(int argc, char* argv[]) Tensor acc0_g_m_n({G0 * G1, M, N}); // scratch object after gemm0 Tensor a1_g_m_n({G0 * G1, M, N}); // scratch object after softmax Tensor c_g_m_o_host_result({G0 * G1, M, O}); // scratch object after gemm1 - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_host_result(f_host_tensor_descriptor( + c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); // permute a_gs_ms_ks.ForEach([&](auto& self, auto idx) { diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc index 7ac29f33ca..1c2a26d916 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -91,11 +95,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc index fb9b1b0bd7..76f3ee756c 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -91,11 +95,30 @@ int run(int argc, char* argv[]) ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc index 2cb69380e5..86754927ed 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc @@ -1,6 +1,10 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + int run(int argc, char* argv[]) { bool do_verification = true; @@ -108,11 +112,30 @@ int run(int argc, char* argv[]) head_dim, 1}; // C layout [batch_size, head_num, sequence_length, head_dim] - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + auto f_host_tensor_descriptor = [](std::vector lens, + std::vector strides, + bool permute, + auto layout) { + if(permute) + { + return HostTensorDescriptor(lens, strides, Bypass{}); + } + else + { + return HostTensorDescriptor(lens, strides, layout); + } + }; + + Tensor a_gs_ms_ks( + f_host_tensor_descriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, input_permute, Row{})); + Tensor b0_gs_ns_ks( + f_host_tensor_descriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, input_permute, Row{})); + Tensor b1_gs_os_ns( + f_host_tensor_descriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, input_permute, Col{})); + Tensor c_gs_ms_os_host_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); + Tensor c_gs_ms_os_device_result( + f_host_tensor_descriptor(c_gs_ms_os_lengths, c_gs_ms_os_strides, output_permute, Row{})); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; @@ -186,7 +209,7 @@ int run(int argc, char* argv[]) head_num * 3 * head_dim, head_dim, 1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim] - Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides); + Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides, Bypass{}); // merge qkv into a packed pointer send to device a_gs_ms_ks.ForEach( [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp index 7ceb1d09ef..1843198933 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; // clang-format on diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp index b5aeff65d6..1e4398b9f6 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_bf16A_i8B.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ReduceDataType>; // clang-format on diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp index cb84f2a416..d5acde139a 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ReduceDataType>; // clang-format on diff --git a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp index 2ab8f77dc4..bb3c23f060 100644 --- a/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp +++ b/example/35_splitK_gemm/gemm_xdl_splitk_reduce_multi_d_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" @@ -35,13 +35,13 @@ using DeviceGemmV2Instance = 256, 128, 128, 64, 8, 4, - 32, 32, - 2, 2, + 16, 16, + 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, - 1, 1, S<1, 32, 1, 8>, 8, + 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v2, ReduceDataType>; // clang-format on diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index 26a03f289d..ff7acea48d 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -50,14 +50,14 @@ template<> struct emb_kernel { using kernel_type = DeviceInsta // clang-format on -int main() +int main(int argc, char* argv[]) { bool time_kernel = true; - constexpr auto num_rows = 65536; - constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{}; - // constexpr auto dims = ck::Sequence<256, 512>{}; - constexpr auto index_length = 2048; + ck::index_t num_rows = 65536; + constexpr auto dims = ck::Sequence<256, 512, 768, 1024, 1536, 2048, 4096, 8192>{}; + ck::index_t index_length = 2048; + ck::index_t dim_mask = 0xffff; constexpr AccDataType epsilon = 1e-4; auto f_host_tensor_desc_1d = [](std::size_t len_) { return HostTensorDescriptor({len_}); }; @@ -73,121 +73,143 @@ int main() BetaDataType, AccDataType, OutType>; + if(argc == 1) + { + // Use default value + } + else if(argc == 5) + { + time_kernel = std::stoi(argv[1]); + num_rows = std::stoi(argv[2]); + dim_mask = strtol(argv[3], nullptr, 0); + index_length = std::stoi(argv[4]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "arg1: time kernel (0=no, 1=yes)" << std::endl; + std::cout << "arg2-4: num_rows dim_mask index_length" << std::endl; + return 1; + } ck::static_for<0, dims.Size(), 1>{}([&](auto I) { - std::srand(std::time(nullptr)); - constexpr auto current_dim = dims.At(I); - Tensor emb_a(f_host_tensor_desc_2d(num_rows, current_dim)); - Tensor emb_b(f_host_tensor_desc_2d(num_rows, current_dim)); - Tensor emb_c(f_host_tensor_desc_2d(num_rows, current_dim)); - - Tensor index_a(f_host_tensor_desc_1d(index_length)); - Tensor index_b(f_host_tensor_desc_1d(index_length)); - Tensor index_c(f_host_tensor_desc_1d(index_length)); - - Tensor gamma(f_host_tensor_desc_1d(current_dim)); - Tensor beta(f_host_tensor_desc_1d(current_dim)); - - Tensor out(f_host_tensor_desc_2d(index_length, current_dim)); - - emb_a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - emb_b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - emb_c.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - index_a.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); - index_b.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); - index_c.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); - - gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize()); - DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize()); - DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize()); - - DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize()); - DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize()); - DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize()); - - DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); - DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); - - DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize()); - - emb_a_dev.ToDevice(emb_a.mData.data()); - emb_b_dev.ToDevice(emb_b.mData.data()); - emb_c_dev.ToDevice(emb_c.mData.data()); - - index_a_dev.ToDevice(index_a.mData.data()); - index_b_dev.ToDevice(index_b.mData.data()); - index_c_dev.ToDevice(index_c.mData.data()); - - gamma_dev.ToDevice(gamma.mData.data()); - beta_dev.ToDevice(beta.mData.data()); - - auto device_instance = typename emb_kernel::kernel_type{}; - auto argument_ptr = device_instance.MakeArgumentPointer( - out_dev.GetDeviceBuffer(), - {ck::type_convert(emb_a_dev.GetDeviceBuffer()), - ck::type_convert(emb_b_dev.GetDeviceBuffer()), - ck::type_convert(emb_c_dev.GetDeviceBuffer())}, - {ck::type_convert(index_a_dev.GetDeviceBuffer()), - ck::type_convert(index_b_dev.GetDeviceBuffer()), - ck::type_convert(index_c_dev.GetDeviceBuffer())}, - gamma_dev.GetDeviceBuffer(), - beta_dev.GetDeviceBuffer(), - current_dim, - index_length, - epsilon, - EmbElementwiseOperation{}); - std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() - << std::endl - << std::flush; - - bool is_supported = device_instance.IsSupportedArgument(argument_ptr.get()); - - if(!is_supported) + if(dim_mask & (1 << I.value)) { - std::cout << "Runtime parameters are not supported" << std::endl; - return; + std::srand(std::time(nullptr)); + constexpr auto current_dim = dims.At(I); + Tensor emb_a(f_host_tensor_desc_2d(num_rows, current_dim)); + Tensor emb_b(f_host_tensor_desc_2d(num_rows, current_dim)); + Tensor emb_c(f_host_tensor_desc_2d(num_rows, current_dim)); + + Tensor index_a(f_host_tensor_desc_1d(index_length)); + Tensor index_b(f_host_tensor_desc_1d(index_length)); + Tensor index_c(f_host_tensor_desc_1d(index_length)); + + Tensor gamma(f_host_tensor_desc_1d(current_dim)); + Tensor beta(f_host_tensor_desc_1d(current_dim)); + + Tensor out(f_host_tensor_desc_2d(index_length, current_dim)); + + emb_a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + emb_b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + emb_c.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + index_a.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + index_b.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + index_c.GenerateTensorValue(GeneratorTensor_2{0, num_rows}); + + gamma.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + beta.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem emb_a_dev(sizeof(EmbType) * emb_a.mDesc.GetElementSpaceSize()); + DeviceMem emb_b_dev(sizeof(EmbType) * emb_b.mDesc.GetElementSpaceSize()); + DeviceMem emb_c_dev(sizeof(EmbType) * emb_c.mDesc.GetElementSpaceSize()); + + DeviceMem index_a_dev(sizeof(IndexType) * index_a.mDesc.GetElementSpaceSize()); + DeviceMem index_b_dev(sizeof(IndexType) * index_b.mDesc.GetElementSpaceSize()); + DeviceMem index_c_dev(sizeof(IndexType) * index_c.mDesc.GetElementSpaceSize()); + + DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize()); + DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize()); + + DeviceMem out_dev(sizeof(OutType) * out.mDesc.GetElementSpaceSize()); + + emb_a_dev.ToDevice(emb_a.mData.data()); + emb_b_dev.ToDevice(emb_b.mData.data()); + emb_c_dev.ToDevice(emb_c.mData.data()); + + index_a_dev.ToDevice(index_a.mData.data()); + index_b_dev.ToDevice(index_b.mData.data()); + index_c_dev.ToDevice(index_c.mData.data()); + + gamma_dev.ToDevice(gamma.mData.data()); + beta_dev.ToDevice(beta.mData.data()); + + auto device_instance = typename emb_kernel::kernel_type{}; + auto argument_ptr = device_instance.MakeArgumentPointer( + out_dev.GetDeviceBuffer(), + {ck::type_convert(emb_a_dev.GetDeviceBuffer()), + ck::type_convert(emb_b_dev.GetDeviceBuffer()), + ck::type_convert(emb_c_dev.GetDeviceBuffer())}, + {ck::type_convert(index_a_dev.GetDeviceBuffer()), + ck::type_convert(index_b_dev.GetDeviceBuffer()), + ck::type_convert(index_c_dev.GetDeviceBuffer())}, + gamma_dev.GetDeviceBuffer(), + beta_dev.GetDeviceBuffer(), + current_dim, + index_length, + epsilon, + EmbElementwiseOperation{}); + std::cout << "Dim:" << current_dim << ", kernel:" << device_instance.GetTypeString() + << std::endl + << std::flush; + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cerr << device_instance.GetTypeString() << " does not support this problem" + << std::endl; + return; + } + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + float time_ms = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + bool pass = true; + { + Tensor out_from_dev(f_host_tensor_desc_2d(index_length, current_dim)); + ReferenceInstance ref; + auto ref_argument = ref.MakeArgument(out, + emb_a, + emb_b, + emb_c, + index_a, + index_b, + index_c, + gamma, + beta, + num_rows, + current_dim, + index_length, + epsilon); + auto ref_invoker = ref.MakeInvoker(); + ref_invoker.Run(ref_argument); + + out_dev.FromDevice(out_from_dev.mData.data()); + pass &= + ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3); + } + + double total_read = current_dim * index_length * 3 * sizeof(EmbType) + + current_dim * sizeof(GammaDataType) + + current_dim * sizeof(BetaDataType); + double total_write = current_dim * index_length * sizeof(OutType); + double gbps = (total_read + total_write) / time_ms / 1e6; + + std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms + << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl + << std::flush; } - - auto invoker_ptr = device_instance.MakeInvokerPointer(); - float time_ms = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - bool pass = true; - { - Tensor out_from_dev(f_host_tensor_desc_2d(index_length, current_dim)); - ReferenceInstance ref; - auto ref_argument = ref.MakeArgument(out, - emb_a, - emb_b, - emb_c, - index_a, - index_b, - index_c, - gamma, - beta, - num_rows, - current_dim, - index_length, - epsilon); - auto ref_invoker = ref.MakeInvoker(); - ref_invoker.Run(ref_argument); - - out_dev.FromDevice(out_from_dev.mData.data()); - pass &= ck::utils::check_err(out_from_dev, out, "Error: Incorrect results", 1e-3, 1e-3); - } - - double total_read = current_dim * index_length * 3 * sizeof(EmbType) + - current_dim * sizeof(GammaDataType) + - current_dim * sizeof(BetaDataType); - double total_write = current_dim * index_length * sizeof(OutType); - double gbps = (total_read + total_write) / time_ms / 1e6; - - std::cout << ", total bytes:" << (total_read + total_write) << ", time:" << time_ms - << ", gbps:" << gbps << ", valid:" << (pass ? "y" : "n") << std::endl - << std::flush; }); return 0; diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp index 904ff761fd..4934f74393 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp +++ b/example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp @@ -321,11 +321,13 @@ int main(int argc, char* argv[]) if(std::is_same::value) { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, stride, 1_uz}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout); } else { - return HostTensorDescriptor({batch_count, row, col}, {batch_stride, 1_uz, stride}); + return HostTensorDescriptor( + {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout); } }; diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc index 0f0b120cbc..80d56cd781 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc @@ -206,7 +206,8 @@ int run_grouped_conv_bwd_data_bias_relu_example(int argc, char* argv[]) 1, // c 0, // hi 0 // wi - }); + }, + ctc::GNCHW{}); // input image: GNHWC const auto in_g_n_c_wis_desc = diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc index 30e0791ebf..3c089688cf 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc @@ -214,7 +214,8 @@ int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_ 1, // k 0, // ho 0 // wo - }); + }, + BiasLayout{}); const auto requant_scale_g_k_desc = bias_g_k_desc; diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc index 32fd435e00..ed7886e76b 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc @@ -201,7 +201,8 @@ int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_el 1, // k 0, // ho 0 // wo - }); + }, + BiasLayout{}); const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc index 362d90b4c1..12fdf425bf 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc @@ -203,7 +203,8 @@ int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_eleme 1, // k 0, // ho 0 // wo - }); + }, + RequantScaleLayout{}); const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); diff --git a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp index ebba88cf41..873982227b 100644 --- a/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp +++ b/example/43_splitk_gemm_bias_e_permute/splitk_gemm_bias_e_permute_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -22,6 +22,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -53,7 +56,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>; + DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, F16, F16, F32, F16, DsDataType, F16, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>; // clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -250,19 +253,24 @@ int main(int argc, char* argv[]) Tensor a_gs_ms_ks( std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), - std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()), + Row{}); Tensor b_gs_ns_ks( std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), - std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()), + Row{}); Tensor d_gs_ms_ns( std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), - std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_device_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -372,7 +380,8 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 #include @@ -22,6 +22,9 @@ using S = ck::Sequence; using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Add = ck::tensor_operation::element_wise::Add; @@ -53,7 +56,7 @@ using DeviceOpInstanceKKNN = ck::tensor_operation::device:: //############################################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Spacialization| Spacialization| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 4, 4, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>; + DeviceSplitKContractionMultipleD_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, ABSpec, ABSpec, DESpec, 1, 256, 256, 128, 32, 4, 4, 16, 16, 8, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 2>; // clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -250,19 +253,24 @@ int main(int argc, char* argv[]) Tensor a_gs_ms_ks( std::vector(a_gs_ms_ks_lengths.begin(), a_gs_ms_ks_lengths.end()), - std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end())); + std::vector(a_gs_ms_ks_strides.begin(), a_gs_ms_ks_strides.end()), + Row{}); Tensor b_gs_ns_ks( std::vector(b_gs_ns_ks_lengths.begin(), b_gs_ns_ks_lengths.end()), - std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end())); + std::vector(b_gs_ns_ks_strides.begin(), b_gs_ns_ks_strides.end()), + Row{}); Tensor d_gs_ms_ns( std::vector(d_gs_ms_ns_lengths.begin(), d_gs_ms_ns_lengths.end()), - std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end())); + std::vector(d_gs_ms_ns_strides.begin(), d_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); Tensor e_gs_ms_ns_device_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; @@ -372,7 +380,8 @@ int main(int argc, char* argv[]) { Tensor c_ms_ns_host_result( std::vector(e_gs_ms_ns_lengths.begin(), e_gs_ms_ns_lengths.end()), - std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end())); + std::vector(e_gs_ms_ns_strides.begin(), e_gs_ms_ns_strides.end()), + Bypass{}); using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1 #include @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -49,6 +51,8 @@ int main(int argc, char* argv[]) bool do_verification = true; bool time_kernel = true; + std::vector nchw = {16, 128, 32, 64}; + if(argc == 1) { // use default @@ -58,14 +62,23 @@ int main(int argc, char* argv[]) do_verification = std::stoi(argv[1]); time_kernel = std::stoi(argv[2]); } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + nchw[0] = std::stoi(argv[3]); + nchw[1] = std::stoi(argv[4]); + nchw[2] = std::stoi(argv[5]); + nchw[3] = std::stoi(argv[6]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: time kernel (0=no, 1=yes)\n"); - exit(0); + printf("arg3-6: N, C, H, W (default 16, 128, 32, 64)\n"); + exit(1); } - std::vector nchw = {16, 128, 32, 64}; std::array ab_lengths; std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), static_cast(nchw[2] * nchw[3]), @@ -73,11 +86,11 @@ int main(int argc, char* argv[]) 1}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 2> as = {Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides)}; + std::array, 2> as = {Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{})}; Tensor& a0 = as[0]; Tensor& a1 = as[1]; - Tensor b(ab_lengths, ab_strides); + Tensor b(ab_lengths, ab_strides, NchwLayout{}); float alpha = 3.f; float beta = 2.f; a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -134,7 +147,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, ab_strides); + Tensor host_b(ab_lengths, ab_strides, NchwLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 9e92543252..2d689648f2 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple @@ -72,9 +74,9 @@ int main(int argc, char* argv[]) static_cast(nhwc[3])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -117,7 +119,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index 88c23b5f40..6e70a306d3 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -23,6 +23,8 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +78,9 @@ int main(int argc, char* argv[]) static_cast(nhwc[0] * nhwc[1])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 1.f; auto i = 0; std::mt19937 gen(11939); @@ -137,7 +139,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index 1185b5a3ca..632d88e88a 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +79,9 @@ int main(int argc, char* argv[]) ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 2.f; a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -128,7 +131,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index 28a3dbc44c..bd54f1c19c 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -22,6 +22,8 @@ using F32 = float; using ADataType = F32; using BDataType = F32; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +78,9 @@ int main(int argc, char* argv[]) static_cast(nhwc[0] * nhwc[1])}; ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 1.f; auto i = 0; @@ -139,7 +141,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index 14d1d96165..9621d591a9 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F32; using BDataType = F32; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -76,9 +79,9 @@ int main(int argc, char* argv[]) ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + std::array, 1> as = {Tensor(ab_lengths, a_strides, NchwLayout{})}; Tensor& a = as[0]; - Tensor b(ab_lengths, b_strides); + Tensor b(ab_lengths, b_strides, NhwcLayout{}); float scale = 2.f; a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -127,7 +130,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, b_strides); + Tensor host_b(ab_lengths, b_strides, NhwcLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 0619cc7139..1ce797e4dd 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -119,6 +119,11 @@ int main(int argc, char* argv[]) bool do_verification = true; bool time_kernel = true; + const float scale = 2.f; + + ck::index_t M = 1024; + ck::index_t K = 1024; + if(argc == 1) { // use default @@ -128,22 +133,19 @@ int main(int argc, char* argv[]) do_verification = std::stoi(argv[1]); time_kernel = std::stoi(argv[2]); } + else if(argc == 5) + { + do_verification = std::stoi(argv[1]); + time_kernel = std::stoi(argv[2]); + M = std::stoi(argv[3]); + K = std::stoi(argv[4]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - const float scale = 2.f; - - ck::index_t M = 1024; - ck::index_t K = 1024; - - if(argc == 3) - { - M = std::stoi(argv[1]); - K = std::stoi(argv[2]); + printf("arg3-4: M(default=1024), K(default=1024)\n"); + exit(1); } std::array dims = {M, K}; diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp index 2583f1cb5e..be4014f636 100644 --- a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -22,6 +22,9 @@ using F32 = float; using ADataType = F16; using BDataType = F16; +using NchwLayout = ck::tensor_layout::convolution::NCHW; +using NhwcLayout = ck::tensor_layout::convolution::NHWC; + using UnaryScale = ck::tensor_operation::element_wise::Scale; using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; using UnaryScaleSquare = @@ -78,13 +81,13 @@ int main(int argc, char* argv[]) ck::ranges::copy(nchw, ab_lengths.begin()); - std::array, 3> as = {Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides), - Tensor(ab_lengths, ab_strides)}; + std::array, 3> as = {Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{}), + Tensor(ab_lengths, ab_strides, NchwLayout{})}; Tensor& a0 = as[0]; Tensor& a1 = as[1]; Tensor& a2 = as[2]; - Tensor b(ab_lengths, ab_strides); + Tensor b(ab_lengths, ab_strides, NchwLayout{}); float alpha = 3.f; float beta = 2.f; float gamma = 4.f; @@ -149,7 +152,7 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor host_b(ab_lengths, ab_strides); + Tensor host_b(ab_lengths, ab_strides, NchwLayout{}); using ReferenceElementwiseInstance = ck::tensor_operation::host:: ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>; auto ref_elementwise = ReferenceElementwiseInstance{}; diff --git a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp index 51006e676b..8064809123 100644 --- a/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp +++ b/example/45_elementwise_normalization/elementwise_layernorm_blockwise.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -98,8 +98,23 @@ int main(int argc, char* argv[]) exit(0); } - ck::index_t M = 48 * 256; - ck::index_t N = 1024; + ck::index_t M = 48 * 256; + ck::index_t N = 1024; + if(argc == 1) + { + // use default case + } + else if(argc == 3) + { + M = std::stoi(argv[1]); + N = std::stoi(argv[2]); + } + else + { + std::cerr << "arg1 to 2: M, N" << std::endl; + return 1; + } + ck::index_t Stride = N; auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { diff --git a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp index 56417b101d..4d73f0c35f 100644 --- a/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp +++ b/example/46_gemm_add_multiply/gemm_add_multiply_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" @@ -31,7 +31,7 @@ using DeviceOpInstance = ck::tensor_operation::device:: //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>; + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, DsLayout, Row, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 16, 16, 8, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfig& config) { using namespace ck::literals; - auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = problem_size; + ProblemSize ps = + problem_size; // make mutable copy because default stride values of 0 need to be updated + auto& [M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE] = ps; - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, int& stride, auto layout) { + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); @@ -123,7 +131,16 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + if(std::is_same_v, ck::half_t> && + std::is_same_v, ck::half_t>) + { + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-3, 1e-3); + } + else + { + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + } } return true; diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp index 1b24bd3bba..3e69caf51e 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp @@ -18,6 +18,10 @@ #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; using AElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -220,12 +224,12 @@ int main(int argc, char* argv[]) std::vector d0_gs_ms_ns_lengths{G0, G1, M, N}; std::vector d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; - Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); - Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); - Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides); - Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); - Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides, Row{}); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides, Row{}); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides, Col{}); + Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides, Row{}); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Row{}); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides, Row{}); std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; diff --git a/example/48_pool3d_fwd/pool3d_fwd_common.hpp b/example/48_pool3d_fwd/pool3d_fwd_common.hpp index 788f38ec52..ef64dd167d 100644 --- a/example/48_pool3d_fwd/pool3d_fwd_common.hpp +++ b/example/48_pool3d_fwd/pool3d_fwd_common.hpp @@ -48,15 +48,16 @@ HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_, if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, layout); } throw std::runtime_error("Pool3d_fwd: problem with layout. "); - return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, layout); }; template ::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz}, layout); } else if constexpr(ck::is_same::value) { - return HostTensorDescriptor({N_, C_, D, H, W}, - {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}); + return HostTensorDescriptor( + {N_, C_, D, H, W}, {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_}, layout); } throw std::runtime_error("Avgpool3d_bwd: problem with layout. "); - return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}); + return HostTensorDescriptor({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, layout); }; template ; // DBetaDstVectorSize -int main() +int main(int argc, char* argv[]) { bool time_kernel = false; @@ -110,6 +110,25 @@ int main() ck::index_t G = 32; ck::index_t C = 64; + if(argc == 1) + { + // use default case + } + else if(argc == 6) + { + N = std::stoi(argv[1]); + H = std::stoi(argv[2]); + W = std::stoi(argv[3]); + G = std::stoi(argv[4]); + C = std::stoi(argv[5]); + } + else + { + std::cerr << "arg1 to 5: N, H, W, G, C" << std::endl; + + return 1; + } + Tensor dy({N, H, W, G, C}); Tensor x({N, H, W, G, C}); Tensor gamma({G, C}); diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp index 5220a4616e..405eac7df1 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_bias_fastgelu_bf16_i8.cpp @@ -81,10 +81,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -120,23 +121,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) N, K, std::array{StrideA}, - std::array{StrideB, 0}, + std::array{StrideB, StrideB1}, std::array{StrideD}, StrideE, a_element_op, diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp index b424fdaf45..50e670bdf3 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fastgelu_bf16_i8.cpp @@ -81,10 +81,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -120,23 +121,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) N, K, std::array{StrideA}, - std::array{StrideB, 0}, + std::array{StrideB, StrideB1}, std::array{}, StrideE, a_element_op, diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp index 03a74c04b7..50e1c21c8f 100644 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp +++ b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_multiply_bias_fastgelu_bf16_i8.cpp @@ -80,10 +80,11 @@ int main(int argc, char* argv[]) ck::index_t N = 768; ck::index_t K = 6144; - ck::index_t StrideA = K; - ck::index_t StrideB = N; - ck::index_t StrideD = 0; - ck::index_t StrideE = N; + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideB1 = 0; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; if(argc == 1) { @@ -119,23 +120,31 @@ int main(int argc, char* argv[]) exit(0); } - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + ck::index_t& stride, + auto layout) { + using namespace ck::literals; - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; + if(std::is_same::value) + { + auto desc = HostTensorDescriptor({row, col}, {static_cast(stride), 1_uz}); + if(stride <= 0) + stride = desc.GetStrides()[0]; + return desc; + } + else + { + auto desc = HostTensorDescriptor({row, col}, {1_uz, static_cast(stride)}); + if(stride <= 0) + stride = desc.GetStrides()[1]; + return desc; + } + }; Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); - Tensor b1_k_n(f_host_tensor_descriptor(K, N, 0, B1Layout{})); + Tensor b1_k_n(f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); @@ -196,7 +205,7 @@ int main(int argc, char* argv[]) K, std::array{StrideA}, std::array{StrideB}, - std::array{0, StrideD}, + std::array{StrideB1, StrideD}, StrideE, a_element_op, b_element_op, @@ -261,7 +270,7 @@ int main(int argc, char* argv[]) { for(int n = 0; n < N; ++n) { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(0, n), d_m_n(m, n)); + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), b1_k_n(m, n), d_m_n(m, n)); } } diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 90e14de59c..a9a30b4c27 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -19,6 +19,9 @@ #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -160,12 +163,12 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides, Bypass{}); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides, Row{}); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; @@ -264,9 +267,9 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) { @@ -299,7 +302,6 @@ int main(int argc, char* argv[]) auto ref_op = ReferenceOpInstance{}; auto ref_invoker = ref_op.MakeInvoker(); - Tensor empty_tensor(std::vector{}, std::vector{}); auto ref_argument = ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op); diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp index ec1b2d6018..4f7414abfa 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -19,6 +19,9 @@ #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/numeric.hpp" +using Row = ck::tensor_layout::gemm::RowMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + template using S = ck::Sequence; @@ -140,12 +143,12 @@ int main(int argc, char* argv[]) exit(0); } - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); - Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); + Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides, Bypass{}); + Tensor b0_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides, Row{}); + Tensor b1_ns_ks(b1_ns_ks_lengths, b1_ns_ks_strides, Row{}); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; @@ -246,9 +249,9 @@ int main(int argc, char* argv[]) if(do_verification) { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides, Row{}); - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); + Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides, Row{}); for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) { @@ -266,7 +269,7 @@ int main(int argc, char* argv[]) } } - Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides); + Tensor b_ns_ks(b0_ns_ks_lengths, b0_ns_ks_strides, Row{}); for(size_t n0 = 0; n0 < b_ns_ks.mDesc.GetLengths()[0]; ++n0) { diff --git a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp index 2afe01f02d..0a802ee27d 100644 --- a/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp @@ -130,11 +130,12 @@ bool run_grouped_conv(bool do_verification, // Fill other lenghts than G,K with 1 and strides with 0 bias_g_k_lengths.fill(1); bias_g_k_strides.fill(0); - bias_g_k_lengths[0] = G; - bias_g_k_lengths[2] = K; - bias_g_k_strides[0] = K; // stride to G - bias_g_k_strides[2] = 1; // stride to K - const auto broadcasted_bias_desc = HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides); + bias_g_k_lengths[0] = G; + bias_g_k_lengths[2] = K; + bias_g_k_strides[0] = K; // stride to G + bias_g_k_strides[2] = 1; // stride to K + const auto broadcasted_bias_desc = + HostTensorDescriptor(bias_g_k_lengths, bias_g_k_strides, BiasLayout{}); // y = relu ( alpha1 * conv(x) + alpha2 * z + bias ) Tensor in(in_g_n_c_wis_desc); diff --git a/example/64_fpAintB_gemm/run_gemm_example.inc b/example/64_fpAintB_gemm/run_gemm_example.inc index dc2bdc18f0..41c8c42bac 100644 --- a/example/64_fpAintB_gemm/run_gemm_example.inc +++ b/example/64_fpAintB_gemm/run_gemm_example.inc @@ -28,7 +28,8 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor quant_b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // assume scale tensor is [1, n] - Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{})); + Tensor scale_k_n( + HostTensorDescriptor({K, N}, {0, 1_uz}, ck::tensor_layout::BypassLayoutVerification())); switch(config.init_method) { diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp index 53963fc514..8b8cee9e52 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp @@ -241,6 +241,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -285,8 +307,6 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; @@ -308,7 +328,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index 7a2d0153d9..8da49ef85d 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -162,6 +162,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -216,7 +238,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{StrideD, StrideD}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp index 5aa978fbf0..3b21f95119 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_ab_scale.cpp @@ -91,6 +91,8 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; + ck::index_t KBatch = 1; + if(argc == 1) { // use default case @@ -101,7 +103,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); time_kernel = std::stoi(argv[3]); } - else if(argc == 8) + else if(argc == 8 || argc == 9) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -113,6 +115,11 @@ int main(int argc, char* argv[]) flush_cache = std::stoi(argv[7]); + if(argc == 9) + { + KBatch = std::stoi(argv[8]); + } + StrideA = K; StrideB = K; StrideE = N; @@ -124,6 +131,7 @@ int main(int argc, char* argv[]) printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 6: M, N, K\n"); printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + printf("arg8: KBatch (default: 1)\n"); exit(0); } @@ -233,9 +241,9 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{}, e_device_buf.GetDeviceBuffer(), @@ -251,6 +259,7 @@ int main(int argc, char* argv[]) a_element_op, b_element_op, cde_element_op); + argument.KBatch = KBatch; if(!device_op.IsSupportedArgument(argument)) { diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index fe1eca51b0..3ee4955ae4 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -251,6 +251,28 @@ int main(int argc, char* argv[]) Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, A0Layout{}, StrideA); + StrideB = get_stride(b0_k_n, B0Layout{}, StrideB); + ck::index_t StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD); + ck::index_t StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; @@ -295,8 +317,6 @@ int main(int argc, char* argv[]) constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; @@ -318,7 +338,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0}, + std::array{StrideD0, StrideD1}, StrideE, KBatch, a_element_op, diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 52ba3416a0..72ea7f1cb6 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -287,15 +287,18 @@ int main(int argc, char* argv[]) } } Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; @@ -422,7 +425,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( - {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}, Row{})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N * 2}, - {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; @@ -463,7 +467,7 @@ int main(int argc, char* argv[]) Tensor b_e_n_k({experts, K, N * 2}); e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); // handle scale before ref. for(int t = 0; t < tokens; ++t) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 92a0cd9e5c..5e306ac6dd 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -264,15 +264,18 @@ int main(int argc, char* argv[]) } Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); Tensor d1_e_n( HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); - Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; @@ -488,7 +491,7 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor d0_t_n( - HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0})); + HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}, Bypass{})); Tensor d1_e_n( HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 354957c0d1..cc42c4b815 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -292,17 +292,19 @@ int main(int argc, char* argv[]) } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + Scale_Block_K - 1) / Scale_Block_K}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k(HostTensorDescriptor( {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N}, - {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 6ca7d67f53..29e758f9d4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -29,8 +29,9 @@ using F16 = ck::half_t; using F8 = ck::f8_t; using F32 = float; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; using A0DataType = F8; using B0DataType = I4; @@ -239,10 +240,10 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); + Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}, Bypass{})); Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/66_complex_contraction_bilinear/common_instances.hpp b/example/66_complex_contraction_bilinear/common_instances.hpp index 480ca5a0af..ed1c1dc303 100644 --- a/example/66_complex_contraction_bilinear/common_instances.hpp +++ b/example/66_complex_contraction_bilinear/common_instances.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -37,7 +37,7 @@ using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>; + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 2, ComputeDataType>; // clang-format on template a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides); + Tensor a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides, DefaultLayout{}); + Tensor b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides, DefaultLayout{}); + Tensor d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides, DefaultLayout{}); - Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); // For Imaginary Part of Complex Tensor - Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides); + Tensor a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides, DefaultLayout{}); + Tensor b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides, DefaultLayout{}); + Tensor d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides, DefaultLayout{}); - Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); // Intermediate E tensor Definition - Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl; std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl; @@ -349,8 +350,10 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[]) if(do_verification) { // Real Part Verification - Tensor c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides); - Tensor c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_re( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor c_ms_ns_host_result_re1( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); using ReferenceOpInstance = ck::tensor_operation::host::ReferenceContraction_M2_N2_K2 c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides); - Tensor c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides); + Tensor c_ms_ns_host_result_img( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); + Tensor c_ms_ns_host_result_img1( + e_ms_ns_lengths, e_ms_ns_strides, DefaultLayout{}); auto ref_argument_img = ref_op.MakeArgument( a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op); diff --git a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp index aaf0cb3891..69c0d6558f 100644 --- a/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp +++ b/example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp @@ -269,10 +269,12 @@ int main(int argc, char* argv[]) Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -281,12 +283,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -480,7 +483,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -278,12 +280,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -477,7 +480,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_k_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor a1_t_k(HostTensorDescriptor( {tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -310,12 +313,13 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2}, - {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_k_n_host_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); Tensor e_t_k_n_device_result( - HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); + HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}, Row{})); e_t_k_n_device_result.SetZero(); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; @@ -506,7 +510,7 @@ int main(int argc, char* argv[]) { invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}); + Tensor c_t_k_n({tokens, topk, N}, {topk * N, N, 1}, Row{}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeMXGemm1 a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -286,7 +288,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp index 829bf9af24..5bb6454d2a 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp @@ -268,16 +268,18 @@ int main(int argc, char* argv[]) } } - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -286,7 +288,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp index efbd0f0c03..333f8a3d52 100644 --- a/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp @@ -303,16 +303,18 @@ int main(int argc, char* argv[]) expert_ids.savetxt("expert_ids.txt", "int"); sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); - Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}, Row{})); Tensor a1_t_k_k( HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize}, - {(topk * Scale_Stride_AM), Scale_Stride_AM, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + {(topk * Scale_Stride_AM), Scale_Stride_AM, 1}, + Row{})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); Tensor b1_e_n_k( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {(N * Scale_Stride_BN), 1, Scale_Stride_BN})); + {(N * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); // B preshuffle - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}, Col{})); // A, B Scale preshuffle Tensor a_scale_sorted(HostTensorDescriptor( @@ -321,7 +323,8 @@ int main(int argc, char* argv[]) {sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1})); Tensor b_scale_preshuffled( HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N}, - {N * Scale_Stride_BN, 1, Scale_Stride_BN})); + {N * Scale_Stride_BN, 1, Scale_Stride_BN}, + Col{})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1})); Tensor e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1})); diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index a2d7fe4aaf..b8ca26193d 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -47,7 +47,7 @@ set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api bwd --receipt 3 - --optdim 32,64,128,256 + --optdim 32,64,96,128,256 # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... ) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 7f55d7412f..2b872cb9b5 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -36,6 +36,13 @@ args: total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode) -s_k seqlen_k (including new key/value), -1 means equal to s (default:-1) + also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode) + -s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1) + Provide positive strides per-batch to simulate physical padding on Q + -s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1) + for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride + along seqlen, instead of packed, same as xformer kv_padding, + must be greater than or equal to s_k -d head dim for q, k (default:128) -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) @@ -76,11 +83,20 @@ args: -repeat number of iterations to benchmark the kernel (default:20) -json 0: No Json, 1: Dump Results in Json format (default:0) -jsonfile json file name to dump results (default:fmha_fwd.json) + -q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override +-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"") + Comma-separated list of length 'b'. If empty, no override ``` Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case +## Padding Examples +Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively. + +Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively. + ## support features Currently we are still in rapid development stage, so more features/optimizations will be coming soon. @@ -128,6 +144,15 @@ Note FA use bottom-right by default to express swa case, here we require you exp ### dropout TBD +### sequence padding and variable length support +We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths. + +**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment. + +**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste. + +Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios. + ## FP8 experimental support As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 802c9e51d7..81d34484a5 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { + "fp32" : "FmhaFwdFp32", "fp16" : "FmhaFwdFp16", "bf16" : "FmhaFwdBf16", "fp8" : "FmhaFwdFp8", @@ -12,6 +13,7 @@ FWD_DTYPE_MAP = { } BWD_DTYPE_MAP = { + "fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16" } diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 0d8f366d8a..e2f69fa49a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -601,6 +601,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 4df99a9a10..059be0e490 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -50,16 +50,10 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape; -using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits; using fmha_mask_{F_idx} = {F_mask}; using fmha_dropout_{F_idx} = {F_dropout}; @@ -94,19 +88,19 @@ using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType, false, - {F_dvpad}>>; + ({F_dvpad} > 0)>>; using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType, false, - {F_dpad}>>; + ({F_dpad} > 0)>>; using fmha_bwd_dq_dk_dv_kernel_{F_idx} = ck_tile::FmhaBwdDQDKDVKernel; + using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>; using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>; + using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>; r = fmha_bwd_>(s, a); return r; }} @@ -278,8 +272,8 @@ class FmhaBwdDQDKDVKernel: F_hdim : int # hdim F_dtype : str # data type F_tile : FmhaBwdDQDKDVTileSize - F_dpad : str # - F_dvpad : str # + F_dpad : Literal[0, 8 ,1] + F_dvpad : Literal[0, 8 ,1] F_bias : str # F_dbias : str # F_dropout : str # @@ -320,8 +314,8 @@ class FmhaBwdDQDKDVKernel: F_wm1 = self.F_tile.F_wm1, F_wn1 = self.F_tile.F_wn1, F_wk1 = self.F_tile.F_wk1, - F_dpad = BOOL_MAP[self.F_dpad], - F_dvpad = BOOL_MAP[self.F_dvpad], + F_dpad = self.F_dpad, + F_dvpad = self.F_dvpad, F_bias = BIAS_MAP[self.F_bias], F_dbias = BOOL_MAP[self.F_dbias], F_dropout = DROPOUT_MAP[self.F_dropout], @@ -337,8 +331,8 @@ class FmhaBwdDQDKDVKernel: def name(self) -> str: def pad_name() -> str: n = '' - if self.F_dpad == 't' : n += 'd' - if self.F_dvpad == 't' : n += 'dv' + if self.F_dpad : n += f'd{self.F_dpad}' + if self.F_dvpad : n += f'dv{self.F_dvpad}' if n != '' : n = 'p' + n return n pn = pad_name() @@ -376,18 +370,30 @@ class FmhaBwdDQDKDVKernel: # TODO: design a more practical way to do it # this is current supported tile size. def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]: - if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': + if dtype == 'fp32' and tr_load == 'f': + return [ + # bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv, + FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1), + ] + elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f': return [ FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), # FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), ] elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't': return [ + FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1), FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + + # FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32), + FmhaBwdDQDKDVTileSize( 32, 16, 64, 32, 64, 32, 16, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 32), # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16), FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16), ] @@ -621,8 +627,8 @@ class FmhaBwdApiTrait: dbias : str dropout : str spad1d : str # spad for 1d kernels (dot/convert) - dpad : str - dvpad : str + dpad : Literal[0, 1, 8] + dvpad : Literal[0, 1, 8] deterministic : str mask_impl : str tr_load : str @@ -651,13 +657,13 @@ class FmhaBwdApiTrait: @property def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' + if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0' + else: return f'a.hdim_q % {self.dpad} == 0' @property def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' + if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0' + else: return f'a.hdim_v % {self.dvpad} == 0' @property def extra_cond(self) -> str: @@ -677,8 +683,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dvpad = 't' if self.dvpad else 'f' return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d, - F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) @property def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: @@ -693,8 +700,9 @@ class FmhaBwdApiTrait: def get_occupancy(dtype, hdim): return 2 + F_dpad = 't' if self.dpad else 'f' return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, - F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad, + F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0) @@ -720,7 +728,7 @@ class FmhaBwdApiPool: F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype], - F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad, F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q, F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond, F_convert_dq_bn0=trait.convert_dq_bn0) @@ -793,7 +801,10 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]): tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load) - for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)): + dpad_options = itertools.product(*([[0, 8, 1]] * 2)) + tf = ["t", "f"] + for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product( + tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf): assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize" hdim = tile.F_bhdq if (mode == "group") and (spad1d == "f"): @@ -804,8 +815,14 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm continue if ("wg32" in dropout): continue - if tr_load == "t" and (dpad == "t" or dvpad == "t"): - continue # tr_load cannot work with dpad or dvpad + if tr_load == "t": + # tr_load can only work with 8 pad + if dpad != dvpad or dpad == 1: + continue + else: # tr_load == "f" + # do not generate instance with only 1 of dpad/dvpad being 8 + if dpad != dvpad and dpad == 8: + continue if optdim_list != [-1]: if hdim not in optdim_list: continue @@ -861,6 +878,30 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= dpad == dvpad + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [64, 128] + cond &= dpad == dvpad + cond &= mode == 'batch' + cond &= bias == 'no' + cond &= dropout == 'no' + cond &= mask == 's_no' + cond &= deterministic == "f" + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + gen_dot_do_o[t.dot_do_o_kernel] = True gen_dq_dk_dv[t.dq_dk_dv_kernel] = True if not t.convert_dq_kernel.disabled: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index cfb96b7d53..533f7f2f23 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -25,6 +25,7 @@ DTYPE_BITS = { K0_MAX_SUBMAX_MAP = { 32 : 32, + 48 : 48, 64 : 64, 96 : 128, 128: 128, @@ -164,7 +165,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - const bool has_load_tr = ck_tile::is_load_tr_supported(); + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} return r; @@ -249,9 +250,8 @@ class FmhaFwdApiTrait: else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False - @property - def seqtune(self) -> str: - if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true + def seqtune(self, max_bm0 : int) -> str: + if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/' else: return f'a.seqlen_q <= {self.bm0}' @@ -259,11 +259,11 @@ class FmhaFwdApiTrait: def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) - else : return f'a.seqlen_k % {self.bn0} == 0' + else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)' elif self.pipeline_tag == 'qr_async_trload': if self.skpad == 't' : return 'true' else: return 'true' @@ -386,6 +386,7 @@ class FmhaFwdApiPool: per_hdim_case=str() for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load] + max_bm0 = max((t.bm0 for t in traits), default=0) inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' @@ -393,7 +394,7 @@ class FmhaFwdApiPool: F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load], - F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, @@ -534,7 +535,20 @@ class KernelComponentFactory: # this is current supported tile size per hdim @staticmethod def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]: - if dtype == 'fp16' or dtype == 'bf16': + if dtype == 'fp32': + return { + # bm0, bn0, bk0, bn1, bk1, + ( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + ( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + (256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)], + } + elif dtype == 'fp16' or dtype == 'bf16': return { (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), @@ -572,7 +586,13 @@ class KernelComponentFactory: # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? pipelines = [] - if dtype in ['fp16', 'bf16']: + if dtype in ['fp32']: + squant = 'f' + for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) + elif dtype in ['fp16', 'bf16']: squant = 'f' for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: @@ -588,7 +608,7 @@ class KernelComponentFactory: else: pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) - if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f": + if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and skip == "f": pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't')) pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": @@ -626,6 +646,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + for tile, next_tile in zip(tiles, tiles[1:]): + assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0' for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': @@ -635,12 +657,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue - if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 - continue - if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): - continue + if dtype != 'fp32': + if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)): + # non qr_async_trload only support km0=128 tile size when hdim is not 128 + # non qr_async only support kn0=128 tile size when hdim is 128 + continue + if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])): + continue # logits_soft_cap is only allowed if no bias if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): continue @@ -710,6 +733,31 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if not cond: continue + # fp32 only, all variations + if receipt == 800: + cond = dtype == 'fp32' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + if not cond: + continue + # fp32 only, minimal set of parameters + elif receipt == 801: + cond = dtype == 'fp32' + cond &= hdim in [48, 128] + cond &= mode == 'batch' + cond &= pipeline.F_bias == 'no' + cond &= pipeline.F_lse == 'f' + cond &= pipeline.F_dropout == 'f' + cond &= pipeline.F_skip == 'f' + cond &= pipeline.F_logits == 'f' + cond &= pipeline.F_mask == 's_no' + if not cond: + continue + else: + # Don't build fp32 by default + if dtype == 'fp32': + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 0ebeaddf9c..38491b56c4 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -184,6 +184,9 @@ class FmhaFwdAppendKVApiPool: per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) @dataclass @@ -341,6 +344,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op cond &= pipeline.F_vlayout == 'row' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index cee1505486..281357ef1e 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -768,6 +768,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt cond &= pipeline.F_squant == 'f' if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) @@ -834,6 +841,13 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim cond = dtype in ['fp16', 'bf16'] if not cond: continue + + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + gen.append(k) return gen diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index df6b422981..3624b7b387 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -560,6 +560,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if not cond: continue + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == 'fp32' + if not cond: + continue + api_pool.register_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp index e0e1fba668..73b3c1e619 100644 --- a/example/ck_tile/01_fmha/example_fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -43,7 +43,7 @@ auto create_args(int argc, char* argv[]) "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("dbias", "0", "output bias gradient or not") - .insert("prec", "fp16", "data type. fp16 or bf16") + .insert("prec", "fp16", "data type. fp32/fp16/bf16") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -159,7 +159,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == bwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == bwd_result::success ? 0 : -2; } diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index 91cb9f55be..c27a5ce1ae 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -33,6 +33,10 @@ auto create_args(int argc, char* argv[]) "0", "seqlen_k for new key/value, 0 means not to use this at all; " "-1 to choose s_knew in [1, s] randomly.") + .insert("s_qpad", + "-1", + "seqlen_q stride between 2 batches (group-mode optional).\n" + "Provide positive strides per-batch to simulate physical padding on Q.") .insert("s_kpad", "-1", "seqlen_k stride between 2 batches, currently used in group-mode only\n" @@ -63,7 +67,7 @@ auto create_args(int argc, char* argv[]) "n or 0, no bias\n" "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" "a(libi) or 2, alibi with 1*h. a:1, b*h") - .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8") .insert("mask", "0", "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n" @@ -107,7 +111,15 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "fmha_fwd.json", "json file name to dump results"); + .insert("jsonfile", "fmha_fwd.json", "json file name to dump results") + .insert("q_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override.") + .insert("kv_eff_lens", + "", + "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n" + "Comma-separated list of length 'b'. If empty, no override."); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -127,6 +139,9 @@ auto run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew"); auto seqlen_kpads = arg_parser.get_int_vec("s_kpad"); + auto seqlen_qpads = arg_parser.get_int_vec("s_qpad"); + auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens"); + auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens"); ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); bool i_perm = arg_parser.get_bool("iperm"); bool o_perm = arg_parser.get_bool("operm"); @@ -174,7 +189,10 @@ auto run(const ck_tile::ArgParser& arg_parser) hdim_q, hdim_v, seqlen_knew, + seqlen_qpads, seqlen_kpads, + q_eff_lens_per_batch, + kv_eff_lens_per_batch, rotary_dim, i_perm, o_perm, @@ -209,7 +227,11 @@ int main(int argc, char* argv[]) return -1; const std::string data_type = arg_parser.get_str("prec"); - if(data_type == "fp16") + if(data_type == "fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp16") { return run(arg_parser) == fwd_result::success ? 0 : -2; } diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index 569c98a458..7ddb65a2db 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -52,7 +52,16 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair get_query_shape() const @@ -172,6 +183,8 @@ struct Problem mask_info mask; TensorLayout input_layout; TensorLayout output_layout; + std::vector q_eff_lens; + std::vector kv_eff_lens; }; struct RunConfig @@ -326,8 +339,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) q_buf.ToDevice(q.data()); k_buf.ToDevice(k.data()); v_buf.ToDevice(v.data()); + // Ensure output buffer is zero-initialized so padded regions compare cleanly + o_buf.SetZero(); - ck_tile::fmha_fwd_v3_args args; + ck_tile::fmha_fwd_v3_args args{}; args.data_type = problem.data_type; args.batch = problem.batch; @@ -380,6 +395,60 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) : problem.seqlen_q * problem.hdim; args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim; + // Optional cumulative seqlen overrides (exclude PAD) + const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1; + const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1; + + auto make_effective_vec = [&](const std::vector& opt_vec, ck_tile::index_t fallback) { + std::vector eff; + if(!opt_vec.empty() && opt_vec[0] != -1) + { + eff.assign(opt_vec.begin(), opt_vec.end()); + if(eff.size() < static_cast(problem.batch)) + { + eff.resize(problem.batch, eff.back()); + } + } + else + { + eff.assign(problem.batch, fallback); + } + return eff; + }; + + const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q); + const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k); + + // Calculate cumulative sums for kernel arguments if varlen is used + std::vector cuq_cum, cukv_cum; + auto calculate_cumulative = [&](const std::vector& per_batch_vec, + std::vector& cum_vec) { + cum_vec.resize(per_batch_vec.size() + 1); + cum_vec[0] = 0; + for(std::size_t i = 0; i < per_batch_vec.size(); ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + }; + + if(has_varlen_q) + { + calculate_cumulative(eff_q_vec, cuq_cum); + } + if(has_varlen_k) + { + calculate_cumulative(eff_kv_vec, cukv_cum); + } + + ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0); + ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0); + cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr); + cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr); + args.cu_seqlen_q_ptr = + !cuq_cum.empty() ? reinterpret_cast(cuq_buf.GetDeviceBuffer()) + : nullptr; + args.cu_seqlen_kv_ptr = + !cukv_cum.empty() ? reinterpret_cast(cukv_buf.GetDeviceBuffer()) + : nullptr; + ck_tile::stream_config stream_config{nullptr, true, /*log_level=*/0, @@ -442,15 +511,72 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) o_ref = o_ref.transpose({0, 2, 1, 3}); } - host::fmha_fwd(q, - k, - v, - problem.mask, - o_ref, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::identity{}, - ck_tile::scales{problem.softmax_scale}); + // If variable lengths are provided, compute per-batch references + // with the effective lengths; else compute a single full reference. + if(has_varlen_q || has_varlen_k) + { + // Variable-length aware verification: zero-fill padded region and only compute valid part. + o_ref.SetZero(); + + for(int b = 0; b < problem.batch; ++b) + { + const ck_tile::index_t seqlen_q_eff = eff_q_vec[b]; + const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b]; + + if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0) + continue; + + // Slice current batch from inputs (bshd) and build single-batch tensors + ck_tile::HostTensor q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::HostTensor k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); + ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + + // Copy effective region + q_b.ForEach([&](auto& self, auto idx) { + // idx: [0, s, h, d] + self(idx) = q(b, idx[1], idx[2], idx[3]); + }); + k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); }); + v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); }); + + // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors) + host::fmha_fwd(q_b, + k_b, + v_b, + problem.mask, + o_b, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + + // Scatter into o_ref's bshd descriptor memory + for(int s = 0; s < seqlen_q_eff; ++s) + { + for(int h = 0; h < problem.nhead_q; ++h) + { + for(int d = 0; d < problem.hdim; ++d) + { + o_ref(b, s, h, d) = o_b(0, s, h, d); + } + } + } + } + } + else + { + // No varlen override: compute the full reference once + host::fmha_fwd(q, + k, + v, + problem.mask, + o_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales{problem.softmax_scale}); + } ck_tile::HostTensor o(problem.get_output_shape()); o_buf.FromDevice(o.data()); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index f1f8eee5e4..6cd1cd94fa 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -15,6 +15,10 @@ #include #include +struct FmhaBwdFp32 +{ +}; + struct FmhaBwdFp16 { }; @@ -26,6 +30,26 @@ struct FmhaBwdBf16 template struct FmhaBwdTypeConfig; +template <> +struct FmhaBwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using GemmDataType = float; + using BiasDataType = float; + using LSEDataType = float; + using AccDataType = float; // data type for gemm accumulation + using DDataType = float; + using RandValOutputDataType = uint8_t; + using ODataType = float; + using OGradDataType = float; + using QGradDataType = float; + using KGradDataType = float; + using VGradDataType = float; + using BiasGradDataType = float; +}; + template <> struct FmhaBwdTypeConfig { @@ -368,8 +392,8 @@ template +auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/) +{ + double rtol = 1e-4; + double atol = 1e-4; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v) { @@ -77,7 +85,9 @@ bwd_result fmha_bwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; @@ -776,7 +786,7 @@ bwd_result fmha_bwd_run(mode_enum mode, // non-deterministic kernels use atomic add to write dq // Some block may be skipped with causal mask and dq are not set to zeros // In these cases thus we need to zero out it first - if(!deterministic || mask.type == mask_enum::no_mask) + if(!deterministic || mask.type != mask_enum::no_mask) dq_acc_buf.SetZero(); ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1}; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index c41e48e6aa..761def6d6a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,6 +17,10 @@ #include #include +struct FmhaFwdFp32 +{ +}; + struct FmhaFwdFp16 { }; @@ -48,6 +52,22 @@ struct FmhaFwdFp8Fp32 template struct FmhaFwdTypeConfig; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = float; + using KDataType = float; + using VDataType = float; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = float; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + template <> struct FmhaFwdTypeConfig { @@ -162,11 +182,20 @@ struct fmha_fwd_args void* lse_ptr; void* o_ptr; + // Optional cumulative sequence length arrays + // Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD) + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] + const void* seqstart_q_ptr; const void* seqstart_k_ptr; const void* seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr + // Group mode: seqstart_padded_* provide physical starts including PAD (optional) + const void* seqstart_padded_q_ptr = nullptr; // [batch+1] + const void* seqstart_padded_k_ptr = nullptr; // [batch+1] + ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -554,7 +583,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.min_seqlen_q, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.seqstart_padded_q_ptr, + args.seqstart_padded_k_ptr); } else { // create batch mode kernel arguments @@ -600,7 +631,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.mask_type, args.p_drop, args.s_randval, - args.drop_seed_offset); + args.drop_seed_offset, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); } }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 43f484fe14..0703af71e3 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -41,6 +41,14 @@ auto get_elimit(std::string /*init_method*/) return ck_tile::make_tuple(rtol, atol); } +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + template <> auto get_elimit(std::string /*init_method*/) { @@ -151,7 +159,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t seqlen_knew, + std::vector seqlen_qpads, std::vector seqlen_kpads, + std::vector q_eff_lens_per_batch, + std::vector kv_eff_lens_per_batch, ck_tile::index_t rotary_dim, bool i_perm, bool o_perm, @@ -177,7 +188,9 @@ fwd_result fmha_fwd_run(mode_enum mode, std::optional json = std::nullopt) { const std::string data_type = []() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) + return "fp32"; + else if constexpr(std::is_same_v) return "fp16"; else if constexpr(std::is_same_v) return "bf16"; @@ -299,6 +312,24 @@ fwd_result fmha_fwd_run(mode_enum mode, #endif const bool use_kvcache = (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size); + // Reject unsupported padding usage in special pipelines (appendkv / splitkv / pagedkv) + const bool has_group_padding = + (mode == mode_enum::group && (!seqlen_qpads.empty() && seqlen_qpads[0] != -1)) || + (mode == mode_enum::group && (seqlen_kpads[0] >= 0)); + const bool has_batch_efflens = (mode == mode_enum::batch && (!q_eff_lens_per_batch.empty() || + !kv_eff_lens_per_batch.empty())); + const bool using_appendkv = (0 < seqlen_knew || 0 < rotary_dim); + const bool using_pagedkv = (0 < page_block_size); + const bool using_splitkv = (num_splits > 1) || use_cache_batch_idx; + if((using_appendkv || using_pagedkv || using_splitkv) && + (has_group_padding || has_batch_efflens)) + { + std::cerr << "Padding (physical or effective lengths) is not supported with " + "appendkv/splitkv/pagedkv pipelines" + << std::endl; + return fwd_result::invalid_args; + } + std::tie(seqlen_qs, seqlen_ks, seqlen_kpads) = generate_missing_seqlens(mode, batch, @@ -362,6 +393,44 @@ fwd_result fmha_fwd_run(mode_enum mode, const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); + // Optional padded Q seqstarts (group-mode only) + std::vector seqstart_q_with_padding_host; + if(mode == mode_enum::group && !seqlen_qpads.empty() && seqlen_qpads[0] != -1) + { + if(seqlen_qpads.size() < static_cast(batch)) + { + seqlen_qpads.resize(batch, seqlen_qpads.back()); + } + if(seqlen_qpads.size() == static_cast(batch)) + { + seqstart_q_with_padding_host = to_seqstarts( + ck_tile::span(seqlen_qpads.data(), seqlen_qpads.size())); + } + } + + // Optional batch-mode cumulative seqlen overrides + std::vector cuq_cum, cukv_cum; + if(mode == mode_enum::batch) + { + auto calculate_cumulative = [&](std::vector& per_batch_vec, + std::vector& cum_vec) { + if(!per_batch_vec.empty() && per_batch_vec[0] != -1) + { + if(per_batch_vec.size() < static_cast(batch)) + { + per_batch_vec.resize(batch, per_batch_vec.back()); + } + cum_vec.resize(batch + 1); + cum_vec[0] = 0; + for(int i = 0; i < batch; ++i) + cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; + } + }; + + calculate_cumulative(q_eff_lens_per_batch, cuq_cum); + calculate_cumulative(kv_eff_lens_per_batch, cukv_cum); + } + using TypeConfig = FmhaFwdTypeConfig; using QDataType = typename TypeConfig::QDataType; @@ -445,8 +514,15 @@ fwd_result fmha_fwd_run(mode_enum mode, // host memory for storing all the tensor elements const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1); - const ck_tile::index_t shape_seqlen_q = + // logical(unpadded) total seqlen_q for group; batch uses fixed seqlen + const ck_tile::index_t shape_seqlen_q_lse = (mode == mode_enum::batch ? seqlen_qs[0] : seqstart_q_host.back()); + // physical(padded) total seqlen_q for group when s_qpad is provided; else use logical + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch + ? seqlen_qs[0] + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host.back() + : seqstart_q_with_padding_host.back())); const ck_tile::index_t shape_seqlen_k = (mode == mode_enum::batch ? seqlen_ks[0] : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() @@ -504,7 +580,7 @@ fwd_result fmha_fwd_run(mode_enum mode, // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] ck_tile::HostTensor lse_host( - lse ? std::array{shape_batch, nhead, shape_seqlen_q} + lse ? std::array{shape_batch, nhead, shape_seqlen_q_lse} : std::array{1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -602,6 +678,16 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() + ? 0 + : seqstart_q_with_padding_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k_padded_buf( + seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 + : cuq_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem cu_seqlen_kv_buf( + cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.size() * sizeof(int32_t) @@ -693,8 +779,14 @@ fwd_result fmha_fwd_run(mode_enum mode, vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); - seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data() - : seqstart_k_with_padding_host.data()); + // Keep logical starts in seqstart_k; pass padded K via separate pointer + seqstart_k.ToDevice(seqstart_k_host.data()); + seqstart_q_padded_buf.ToDevice( + seqstart_q_with_padding_host.empty() ? nullptr : seqstart_q_with_padding_host.data()); + seqstart_k_padded_buf.ToDevice(seqlen_kpads[0] < 0 ? nullptr + : seqstart_k_with_padding_host.data()); + cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); + cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0] ? seqlen_ks.data() : nullptr); @@ -747,6 +839,54 @@ fwd_result fmha_fwd_run(mode_enum mode, std::cout << ", cache_batch_idx:" << use_cache_batch_idx; } #endif + // Padding / effective length diagnostic logging + auto print_vec = [&](const char* label, const std::vector& v) { + if(v.empty()) + return; + std::cout << ", " << label << ":["; + for(std::size_t i = 0; i < v.size(); ++i) + { + if(i) + std::cout << ","; + std::cout << v[i]; + } + std::cout << "]"; + }; + + if(has_group_padding) + { + bool has_qpad = !seqstart_q_with_padding_host.empty(); + bool has_kpad = (seqlen_kpads[0] >= 0); + if(has_qpad) + { + print_vec("q_logical", seqlen_qs); + print_vec("q_padded", seqlen_qpads); + } + if(has_kpad) + { + print_vec("k_logical", seqlen_ks); + print_vec("k_padded", seqlen_kpads); + } + } + else if(has_batch_efflens) + { + // derive effective lengths from cumulative arrays if present + if(!cuq_cum.empty()) + { + std::vector eff_q(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_q[b_i] = static_cast(cuq_cum[b_i + 1] - cuq_cum[b_i]); + print_vec("q_eff", eff_q); + } + if(!cukv_cum.empty()) + { + std::vector eff_kv(batch); + for(int b_i = 0; b_i < batch; ++b_i) + eff_kv[b_i] = static_cast(cukv_cum[b_i + 1] - cukv_cum[b_i]); + print_vec("kv_eff", eff_kv); + } + } + std::cout << std::flush; const auto init_traits = [&](auto& traits) { @@ -830,8 +970,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q_lse; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); // setup batch_stride_* arguments @@ -846,8 +986,8 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_vnew = (nhead_k * hdim_v * seqlen_knew); const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * max_seqlen_k); const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q); - const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q_lse); + const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q_lse); const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); @@ -961,6 +1101,29 @@ fwd_result fmha_fwd_run(mode_enum mode, { args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); } + + // Group-mode: optional physical padded starts for Q/K + if(mode == mode_enum::group) + { + args.seqstart_padded_q_ptr = (seqstart_q_with_padding_host.empty() + ? nullptr + : seqstart_q_padded_buf.GetDeviceBuffer()); + args.seqstart_padded_k_ptr = + (seqlen_kpads[0] < 0 ? nullptr : seqstart_k_padded_buf.GetDeviceBuffer()); + } + + // Batch-mode: optional cumulative effective seqlen overrides + if(mode == mode_enum::batch) + { + args.cu_seqlen_q_ptr = cuq_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_q_buf.GetDeviceBuffer()); + args.cu_seqlen_kv_ptr = cukv_cum.empty() + ? nullptr + : reinterpret_cast( + cu_seqlen_kv_buf.GetDeviceBuffer()); + } } else if constexpr(std::is_same_v>) { @@ -1000,7 +1163,7 @@ fwd_result fmha_fwd_run(mode_enum mode, } }; - const float appendkv_ave_time = [&] { + auto run_appendkv = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_APPENDKV_API if(need_append_kvcache) { @@ -1010,18 +1173,19 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_appendkv_args fwd_appendkv_args; init_args(fwd_appendkv_args); - return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, stream_config); + return fmha_fwd_appendkv(fwd_appendkv_traits, fwd_appendkv_args, sc); } #endif return 0.0f; - }(); + }; + const float appendkv_ave_time = run_appendkv(stream_config); if(appendkv_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; return fwd_result::no_instance; } - const float fwd_ave_time = [&] { + auto run_fwd = [&](const ck_tile::stream_config& sc) { #if CK_TILE_FMHA_FWD_PAGEDKV_API if(1 == num_splits && use_kvcache) { @@ -1031,8 +1195,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_pagedkv_args fmha_pagedkv_args; init_args(fmha_pagedkv_args); - const float ave_time = - fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config); + const float ave_time = fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, sc); #if CK_TILE_FMHA_FWD_SPLITKV_API // If there is no instance for these args, fallback to fmha_fwd_splitkv if(ave_time >= 0.0f) @@ -1051,7 +1214,7 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_splitkv_args fmha_splitkv_args; init_args(fmha_splitkv_args); - return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, stream_config); + return fmha_fwd_splitkv(fmha_splitkv_traits, fmha_splitkv_args, sc); } #endif // CK_TILE_FMHA_FWD_SPLITKV_API fmha_fwd_traits fmha_traits; @@ -1060,8 +1223,9 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_args fmha_args; init_args(fmha_args); - return fmha_fwd(fmha_traits, fmha_args, stream_config); - }(); + return fmha_fwd(fmha_traits, fmha_args, sc); + }; + const float fwd_ave_time = run_fwd(stream_config); if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl; @@ -1135,6 +1299,17 @@ fwd_result fmha_fwd_run(mode_enum mode, } else { +#if CK_TILE_FMHA_FWD_APPENDKV_API + // When rotary embedding is used, the appendkv kernel modifies the q tensor (multiple times + // when time_kernel_ is set). We need to reset the q buffer and rerun all kernels. + if(0 < rotary_dim && stream_config.time_kernel_) + { + const ck_tile::stream_config stream_config2{stream_config.stream_id_, false, 0}; + q_buf.ToDevice(q_host.data()); + run_appendkv(stream_config2); + run_fwd(stream_config2); + } +#endif o_buf.FromDevice(o_host.data()); lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); @@ -1167,15 +1342,29 @@ fwd_result fmha_fwd_run(mode_enum mode, for(ck_tile::index_t wb = 0; wb < batch; ++wb) { - const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; - const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + if(mode == mode_enum::batch) + { + if(!cuq_cum.empty()) + { + real_seqlen_q = cuq_cum[wb + 1] - cuq_cum[wb]; + } + if(!cukv_cum.empty()) + { + real_seqlen_k = cukv_cum[wb + 1] - cukv_cum[wb]; + } + } // adjust matrix index according to the mode const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0); const ck_tile::index_t cache_b_idx = (use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx); const ck_tile::index_t query_offset = - (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + (mode == mode_enum::batch + ? 0 + : (seqstart_q_with_padding_host.empty() ? seqstart_q_host[wb] + : seqstart_q_with_padding_host[wb])); const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 @@ -1538,8 +1727,10 @@ fwd_result fmha_fwd_run(mode_enum mode, if(lse) { ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + const ck_tile::index_t query_offset_lse = + (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); lse_host_result.ForEach([&](auto& self, auto idx) { - self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset); + self(idx) = lse_host(b_idx, idx[0], idx[1] + query_offset_lse); }); cur_pass = ck_tile::check_err(lse_host_result, diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 10cb5149a4..4bd1d1a367 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -56,6 +56,11 @@ struct fmha_fwd_v3_args index_t stride_o; index_t nhead_stride_o; index_t batch_stride_o; + + // Optional batch-mode cumulative seqlen overrides (exclude PAD) + // If provided, they override per-batch effective lengths to skip tail padding. + const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1] + const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1] }; std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type); diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index e0fbad39a5..194675f962 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -158,7 +158,9 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.window_size_left, args.window_size_right, args.mask_type, - remap_opt); + remap_opt, + args.cu_seqlen_q_ptr, + args.cu_seqlen_kv_ptr); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd.sh b/example/ck_tile/01_fmha/script/benchmark_fwd.sh index 88c16cceb6..31ad800039 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd.sh @@ -18,3 +18,36 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn done done done + +#Padding Benchmarks: batch mode (baseline vs low/med/high pad) +prec="fp16" +base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no pad) +$EXE $base_batch_args + +# low pad (≈90–95% effective) +$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + +# Padding Benchmarks: group mode (baseline vs low/med/high physical pad) +seqlens_q="1024,768,512,256" +seqlens_k="1024,768,512,256" +base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID" + +# baseline (no physical pad) +$EXE $base_group_args + +# low physical pad +$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + +# medium physical pad +$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + +# high physical pad +$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index b847e85398..a3f7d68eb3 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -23,3 +23,20 @@ done done done done + +# Padding benchmark comparisons for v3 (batch mode only) +# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ==== +prec="fp16" +base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID" + +# baseline (no pad) +$EXE $base_v3_args + +# low pad (≈90–95% effective) +$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + +# medium pad (≈60–75% effective) +$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + +# high pad (≈30–40% effective) +$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index e7babd2744..5c2a5a4b3d 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -34,15 +34,15 @@ function print_log_header(){ } #run verification tests -example/ck_tile/01_fmha/script/smoke_test_fwd.sh -example/ck_tile/01_fmha/script/smoke_test_bwd.sh +time example/ck_tile/01_fmha/script/smoke_test_fwd.sh +time example/ck_tile/01_fmha/script/smoke_test_bwd.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" print_log_header $fmha_fwd_log $env_type $branch $host_name -example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log +time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log" print_log_header $fmha_bwd_log $env_type $branch $host_name -example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log +time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index 3b59505ff0..cd51dde2d4 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -6,7 +6,7 @@ SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) EXE_NAME=tile_example_fmha_bwd EXE="$(find . -name $EXE_NAME -type f | head -n 1)" KNAME=1 -GPU_arch=$GPU_arch +GPU_arch=${GPU_arch:-""} if [ -z "$GPU_arch" ] ; then GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') fi @@ -31,7 +31,17 @@ run_exe() { set -ex } +test_h_s_mask() { + run_exe -b=1 -h=4 -h_k=2 -s=259 $@ + run_exe -b=2 -h=2 -s=516 -s_k=253 $@ + run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 $@ + run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 $@ + run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 $@ + run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 $@ +} + set -x +# main tests for prec in "fp16" "bf16" ; do for perm in 0 1 ; do for hdim in 32 64 128 256 ; do @@ -40,21 +50,21 @@ for bias in "n" "a" ; do for dbias in 0 ; do for p_drop in 0.0 0.2 ; do for deterministic in 0 ; do +test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +done +done +done +done +done +done +done +done -run_exe -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -run_exe -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS - -done -done -done -done -done -done -done +# additional cases +for hdim in 40 48 72 96 ; do +test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS +test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS done set +x diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index afd0c728c6..fca6b8d0cd 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -137,9 +137,118 @@ run_fp16_appendkv_tests() { done ; done ; done } +run_padding_smoke_tests() { + # Padding-only smoke tests for batch/group mode using COMMON_ARGS + local prec="fp16" + + # Batch mode: padding via effective lengths (exclude PAD) + # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches + local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low pad (≈90–95% effective) + $EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896 + # medium pad (≈60–75% effective) + $EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640 + # high pad (≈30–40% effective) + $EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320 + + # Group mode: padding via physical stride along seqlen + local seqlens_q="1024,768,512,256" + local seqlens_k="1024,768,512,256" + local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS" + # low physical pad + $EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320 + # medium physical pad + $EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384 + # high physical pad + $EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512 +} + +run_padding_basic_boundary_tests() { + # Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh) + local prec + local perm + + # Group mode: Q&K padded with per-batch different strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \ + -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # slightly larger, uneven padding strides + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \ + -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # only K padded; Q unpadded + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \ + -s=55 -s_k=256 -s_kpad=272,260 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # use cu_seqlen overrides to skip tail PAD + for prec in fp16 bf16 ; do + for perm in 0 1 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \ + -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \ + -bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \ + -q_eff_lens=55,60 -kv_eff_lens=200,256 \ + -bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \ + -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS + done + done + + # no padding (equal), mixed Q/KV, all len=1 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + + $EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \ + -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # highly variable logical lengths + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \ + -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \ + -bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS + done + + # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16 + for prec in fp16 bf16 ; do + $EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \ + -s=256,129 -s_k=256,129 -s_kpad=256 \ + -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \ + -kname=$KNAME $COMMON_ARGS + done +} + set -x run_fp16_bf16_tests +run_padding_smoke_tests +run_padding_basic_boundary_tests run_fp8_tests run_fp8bf16_tests run_fp8fp32_tests diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index b7512b2999..5f589db8d0 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,6 +75,39 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps; + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); + } + }(); + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index f0a71670ad..9ece1638b5 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -231,7 +231,7 @@ struct SplitKTwoStageInvoker preprocess = clear_gemm_output; } - return ck_tile::launch_kernel_time_mask( + ave_time = ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel( @@ -245,20 +245,21 @@ struct SplitKTwoStageInvoker ck_tile::make_tuple(args.N, 1), // Output Stride input_tensors, static_cast(c_ptr))); + + return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); + return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); } else { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index f200332588..dd13ed7bba 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -275,30 +275,29 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { // For workspace mode, always use SET operation since each K-split writes to separate memory - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } /** diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index af740dc0f1..25ed6f9ac5 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -16,8 +16,9 @@ #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 +#define CK_TILE_PIPELINE_COMPUTE_V6 5 +#define CK_TILE_PIPELINE_PRESHUFFLE_V1 6 +#define CK_TILE_PIPELINE_PRESHUFFLE_V2 7 template constexpr ck_tile::index_t get_k_warp_tile() @@ -251,9 +252,29 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; +}; + +template +struct GemmConfigComputeV6 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6; + static constexpr ck_tile::index_t NumWaveGroups = 1; }; template @@ -484,6 +505,15 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; }; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6; +}; + template <> struct PipelineTypeTraits { diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index dca4ef8f40..090a3a9f65 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -75,6 +75,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) ck_tile::bf8_t, ck_tile::half_t>(a_layout, b_layout, arg_parser); } + else if(data_type == "int4") + { + return run_gemm_example_prec_type, + ck_tile::fp8_t, + ck_tile::pk_int4_t, + ck_tile::half_t>(a_layout, b_layout, arg_parser); + } else { throw std::runtime_error("Unsupported data type for this operation !!!"); diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp index d737a0f864..023b0336fe 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp @@ -194,10 +194,7 @@ struct WeightPreshuffleInvoker } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + throw std::runtime_error("split-k is not supported yet!"); } }; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index de17628ab3..9f57609523 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -309,16 +309,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, if(init_method == 0) { - if constexpr(preshuffle) - { - ck_tile::FillUniformDistribution{-.5f, .5f}(a_m_k); - ck_tile::FillUniformDistribution{-.5f, .5f}(b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); - } + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); } else if(init_method == 1) { @@ -362,6 +354,10 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser, } }(); // shuffled buffer B for device implementation + if constexpr(std::is_same_v) + { + ck_tile::permute_vectors_i4x4_b(b_shuffle_host); + } b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); } else diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 19855c7f72..d0fd69b1e2 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -174,24 +174,25 @@ struct UniversalInvoker preprocess = clear_gemm_output; } - return ck_tile::launch_kernel_time_mask( + ave_time = ck_tile::launch_kernel_time_mask( s, preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run(has_hot_loop_, tail_number_, MemoryOpSet{}); + return Run(has_hot_loop_, tail_number_, MemoryOpSet{}); } else { - Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } }; diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index 0e948322a2..75d7abd0ad 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -75,6 +75,39 @@ struct rmsnorm2d_fwd_traits_ using YScaleDataType = ck_tile::remove_cvref_t; using UnquantYDataType = ck_tile::remove_cvref_t; + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps; + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); + } + }(); + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; @@ -605,15 +638,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1), h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1), h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)] - } + } } - + total_blob = list() for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive current_trait_dict = h_trait_dicts[model_sensitive_flag] for hs_key in current_trait_dict: - hs = current_trait_dict[hs_key] + hs = current_trait_dict[hs_key] current_n = hs_key for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): prec_i, prec_o = dtype.split(',') diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 6e2664e9ba..8518b5ddc7 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -70,16 +70,16 @@ template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - float epsilon = arg_parser.get_float("e"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int fused_add = arg_parser.get_int("fadd"); - int fused_quant = arg_parser.get_int("fquant"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - const int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + float epsilon = arg_parser.get_float("e"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); if(x_stride < 0) @@ -196,6 +196,11 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); + if(n > 8192) + { + use_model_sensitive_rmsnorm = 0; + } + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush; @@ -297,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const int N = acc_.mDesc.get_lengths()[1]; for(int n_ = 0; n_ < N; ++n_) { - o_unquant_(m_, n_) = ck_tile::type_convert(acc_(m_, n_)); + o_unquant_(m_, n_) = ck_tile::type_convert(acc_(m_, n_)); } dquant_functor(m_, o_, acc_); @@ -316,7 +321,8 @@ bool run(const ck_tile::ArgParser& arg_parser) invRms_host_ref, unquant_y_host_ref, epsilon, - default_and_dquant_functor); + default_and_dquant_functor, + use_model_sensitive_rmsnorm); } else { @@ -331,7 +337,8 @@ bool run(const ck_tile::ArgParser& arg_parser) invRms_host_ref, unquant_y_host_ref, epsilon, - dquant_functor); + dquant_functor, + use_model_sensitive_rmsnorm); } } else @@ -343,7 +350,14 @@ bool run(const ck_tile::ArgParser& arg_parser) YDataType, InvRmsDataType, ck_tile::null_type>( - x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon); + x_host, + gamma_host, + y_host_ref, + invRms_host_ref, + unquant_y_null, + epsilon, + ck_tile::reference_rmsnorm2d_default_epilogue{}, + use_model_sensitive_rmsnorm); } y_buf.FromDevice(y_host_dev.data()); @@ -354,6 +368,11 @@ bool run(const ck_tile::ArgParser& arg_parser) y_residual_buf.FromDevice(y_residual_host_dev.data()); } + if constexpr(SaveUnquant) + { + unquant_y_buf.FromDevice(unquant_y_host_dev.data()); + } + auto [rtol, atol] = get_elimit(); if(x_stride == n) { diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh index 1c79dafadd..3a0f7dbb66 100755 --- a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -1,49 +1,85 @@ -#!/bin/sh +#!/bin/bash + EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" -for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\ - "-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do -for pr_i in "fp16" "bf16" ; do -for fadd in "0" "1"; do -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm -for s in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192 -done -done -done +total=0 +valid=0 + +run_case() { + cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7" + echo "[CMD] $cmd" + output=$($cmd 2>&1) + echo "$output" + if echo "$output" | grep -q "valid:y"; then + valid=$((valid + 1)) + fi + total=$((total + 1)) +} + +fquant_list=( + "" + "-fquant=1 -prec_o=int8" + "-fquant=2 -prec_o=int8" + "-fquant=1 -prec_o=fp8" + "-fquant=2 -prec_o=fp8" + "-fquant=1 -prec_o=int8 -save_unquant=1" + "-fquant=2 -prec_o=int8 -save_unquant=1" + "-fquant=1 -prec_o=fp8 -save_unquant=1" + "-fquant=2 -prec_o=fp8 -save_unquant=1" +) + +m_n_list=( + "99 13" "17 16" "1 100" "4 128" "80 127" + "7 599" "19 512" "11 510" "91 636" + "31 1024" "8 1501" "3 1826" "5 2040" + "7 2734" "1 3182" "9 4096" "3 8192" +) + +### Add special stride test ### +m_n_stride_list=( + "22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256" + "33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000" + "171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818" + "12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800" + "100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812" + "64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004" +) + +for fquant in "${fquant_list[@]}"; do + for pr_i in "fp16" "bf16"; do + for fadd in "0" "1"; do + for s in "0" "1"; do + for pair in "${m_n_list[@]}"; do + m=$(echo $pair | cut -d ' ' -f1) + n=$(echo $pair | cut -d ' ' -f2) + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "" + done + + ### Running tests with stride ### + for triple in "${m_n_stride_list[@]}"; do + m=$(echo $triple | cut -d ' ' -f1) + n=$(echo $triple | cut -d ' ' -f2) + stride_args=$(echo $triple | cut -d ' ' -f3-) + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args" + done + done + done + done done -# The following cases uses two pass pipeline which doesn't support quant epilogue. -for fquant in "" -for pr_i in "fp16" "bf16" ; do -for fadd in "0" "1"; do -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm -for s in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547 -#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 -done -done -done +# Special two-pass only +for pr_i in "fp16" "bf16"; do + for fadd in "0" "1"; do + for s in "0" "1"; do + run_case "$pr_i" "$fadd" "$s" "" "1" "10547" "" + done + done done + +# Summary +echo "==============================" +echo "Total cases: $total" +echo "Valid cases: $valid" +accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}") +echo "Accuracy: $accuracy%" +echo "==============================" diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index d614b8462a..00c6be8f10 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -194,22 +194,40 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return -1; } -#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - constexpr bool local_token = local_token_; \ - using ms_problem = ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -286,6 +304,46 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() +#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ + } + #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ if(t.local_expert_masking) \ { \ @@ -294,7 +352,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ return ave_time; \ @@ -304,7 +362,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ return ave_time; \ @@ -317,7 +375,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = \ ck_tile::launch_kernel(s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ return ave_time; \ @@ -327,7 +385,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi float ave_time = ck_tile::launch_kernel( \ s, \ maybe_clear_workspace, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ return ave_time; \ @@ -369,69 +427,140 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co } }; - if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > - ck_tile::get_smem_capacity()) + if(a.tokens < 2048) { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { #if MOE_SORTING_SUPPORT_LARGE_EXPERT - if(t.local_expert_masking) - { - float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, - MOE_SORTING_MP_0(ms_index_t, 1, true), - MOE_SORTING_MP_1(ms_index_t, 1, true), - MOE_SORTING_MP_2(ms_index_t, 1, true), - MOE_SORTING_MP_3(ms_index_t, 1, true)); - return ave_time; - } - else - { - float ave_time = ck_tile::launch_kernel(s, - maybe_clear_workspace, - MOE_SORTING_MP_0(ms_index_t, 1, false), - MOE_SORTING_MP_1(ms_index_t, 1, false), - MOE_SORTING_MP_2(ms_index_t, 1, false), - MOE_SORTING_MP_3(ms_index_t, 1, false)); - return ave_time; - } -#else - printf("do not support large expert %d\n", a.num_experts); - return -1; -#endif - } - else - { - ck_tile::index_t mesh_byte_size = - ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); - if(mesh_byte_size == 1) - { - if(a.tokens * a.topk % 4 == 0) + if(t.local_expert_masking) { - MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; } else { - MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) - } - } - else if(mesh_byte_size == 2) - { -#if MOE_SORTING_SUPPORT_LARGE_TOPK - if(a.tokens * a.topk % 4 == 0) - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) - } - else - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; } #else - printf("do not support large topk %d\n", a.topk); + printf("do not support large expert %d\n", a.num_experts); return -1; #endif } else { - MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1) + } + } + } + else + { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } } diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 441aa84edf..5edb74f52f 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -198,22 +198,40 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return -1; } -#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - constexpr bool local_token = local_token_; \ - using ms_problem = ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + constexpr bool local_token = local_token_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() #define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \ @@ -290,6 +308,46 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() +#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ + return ave_time; \ + } \ + } \ + else \ + { \ + if(is_local_token) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = ck_tile::launch_kernel( \ + s, \ + MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ + return ave_time; \ + } \ + } + #define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ if(t.local_expert_masking) \ { \ @@ -297,7 +355,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = \ ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \ return ave_time; \ @@ -306,7 +364,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = \ ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \ return ave_time; \ @@ -318,7 +376,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = \ ck_tile::launch_kernel(s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \ return ave_time; \ @@ -327,7 +385,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til { \ float ave_time = ck_tile::launch_kernel( \ s, \ - MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \ + MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \ MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \ MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \ return ave_time; \ @@ -344,67 +402,156 @@ float fused_moesorting_mp(fused_moesorting_trait t, using ms_index_t = ck_tile::index_t; using ms_weight_type = float; - if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > - ck_tile::get_smem_capacity()) + auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) { + if(t.clear_workspace_inside_api) + { + if(is_local_token) + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1); + k(s_); + } + else + { + auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1); + k(s_); + } + } + }; + + if(a.tokens < 2048) { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { #if MOE_SORTING_SUPPORT_LARGE_EXPERT - if(t.local_expert_masking) - { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(ms_index_t, 1, true), - MOE_SORTING_MP_1(ms_index_t, 1, true), - MOE_SORTING_MP_2(ms_index_t, 1, true), - MOE_SORTING_MP_3(ms_index_t, 1, true)); - return ave_time; - } - else - { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(ms_index_t, 1, false), - MOE_SORTING_MP_1(ms_index_t, 1, false), - MOE_SORTING_MP_2(ms_index_t, 1, false), - MOE_SORTING_MP_3(ms_index_t, 1, false)); - return ave_time; - } -#else - printf("do not support large expert %d\n", a.num_experts); - return -1; -#endif - } - else - { - ck_tile::index_t mesh_byte_size = - ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); - if(mesh_byte_size == 1) - { - if(a.tokens * a.topk % 4 == 0) + if(t.local_expert_masking) { - MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; } else { - MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) - } - } - else if(mesh_byte_size == 2) - { -#if MOE_SORTING_SUPPORT_LARGE_TOPK - if(a.tokens * a.topk % 4 == 0) - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) - } - else - { - MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + float ave_time = + ck_tile::launch_kernel(s, + MOE_SORTING_MP_0_V2(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; } #else - printf("do not support large topk %d\n", a.topk); + printf("do not support large expert %d\n", a.num_experts); return -1; #endif } else { - MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1) + } + } + } + else + { + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = + ck_tile::launch_kernel(s, + maybe_clear_workspace, + MOE_SORTING_MP_0_V1(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } } diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 4f3b173c55..bbfb2df006 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,10 +1,12 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) +add_executable(tile_example_grouped_gemm_multi_d EXCLUDE_FROM_ALL grouped_gemm_multi_d.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) \ No newline at end of file diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index 94481fa7b7..09bf3e167a 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -1,140 +1,8 @@ -# Grouped Gemm - -Grouped General Matrix Multiplication (Grouped GEMM) is a technique used in GPU computing and high-performance computing to batch together multiple independent GEMM operations (matrix multiplications) into a single kernel launch in order to improve performance and efficiency. This folder contains Grouped GEMM examples that use the ck_tile tile-programming implementation. - ## Quick Tour for New Users The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. Unlike regular batched GEMM operations where both matrices must be of the same size and have the same configuration, Grouped GEMM operations can take matrices with different sizes and configurations, making them more flexible for diverse workloads. -Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function. - -### Key Arguments -The example takes several arguments including `group_count`, `repeat`, and `warmup`: -- `group_count`: the number of GEMM operations in the group -- `repeat`: the number of times to repeat the kernel for benchmarking -- `warmup`: the number of iterations before the actual kernel run time measure - -```cpp -// Example -const int group_count = arg_parser.get_int("group_count"); -const int repeat = arg_parser.get_int("repeat"); -const int warmup = arg_parser.get_int("warmup"); -``` -In the next step, the input parameters `Ms`, `Ns`, `Ks`, as well as the corresponding `stride_As`, `stride_Bs`, and `stride_Cs` are either provided from the comand line or generated by default. Since one or more input data sets are expected for `A` and `B`, each parameter is stored in a `std::vector`. The size of the `vector` is defined by `group_count`. - -```cpp -// Example -std::vector Ms = arg_parser.get_int_vec("Ms"); -std::vector Ns = arg_parser.get_int_vec("Ns"); -std::vector Ks = arg_parser.get_int_vec("Ks"); -std::vector stride_As = arg_parser.get_int_vec("stride_As"); -std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); -std::vector stride_Cs = arg_parser.get_int_vec("stride_Cs"); -``` -Where: -- `Ms` is the M dimension of each GEMM. -- `Ns` is the N dimension of each GEMM. -- `Ks` is the K dimension of each GEMM. -- `stride_As` is the stride values for matrix A. -- `stride_Bs` is the stride values for matrix B. -- `stride_Cs` is the stride values for matrix C. - -### HostTensor and Device Memory Buffers (for CPU and GPU) -Each parameter `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs` and `stride_Cs` contains values for more than one matrix, meaning different matrix sizes and strides can be used for different grouped GEMM computations. -The next step is to properly load the input values. For each input matrix, `A` and `B`, and for each output matrix, `C`, you need to create both `HostTensor` and `DeviceMemory`, where: -- `HostTensor` represents the matrix data on the host (CPU). It stores the data before they are transferred to the device for computation. -- `DeviceMemory` represents the matrix data on the device (GPU). This will store the data on the GPU for computation during the Grouped GEMM operation. - -#### HostTensor Buffers (for CPU) -In the first step, create `HostTensor` for `A`, `B`, `C`. `HostTensor` allocates memory on the host (CPU) to store the matrices, initializing the memory with the appropriate dimensions and values to store the data. Below is an example code showing how to create HostTensors for those tensors: -```cpp -// Example -std::vector> a_m_k_tensors; -std::vector> b_k_n_tensors; -std::vector> c_m_n_tensors; -``` -Where: -- `a_m_k_tensors` is the vector of `HostTensor` objects for matrix `A` (with dimensions `M × K`). Each tensor stores the data for single GEMM operation. -- `b_k_n_tensors` is the vector of `HostTensor` objects for matrix `B` (with dimensions `K × N`). -- `c_m_n_tensors` is the vector of `HostTensor` objects for matrix `C` (the output matrix with dimensions `M × N`). - -The `std::vector` container is used for this purpose throughout. As mentioned above, the number of HostTensors is equal to `group_count`. - -#### Device Memory Buffers (for GPU) -Now it's time to allocate memory on the device (GPU) and transfer the data from `HostTensor` to `DeviceMemory` for actual computation.. -```cpp -// Example -std::vector> a_m_k_dev_buf; -std::vector> b_k_n_dev_buf; -std::vector> c_m_n_dev_buf; -``` -Where: -- `a_m_k_dev_buf` is the buffer used for storing matrix A on the GPU. -- `b_k_n_dev_buf` is the buffer used for storing matrix B on the GPU. -- `c_m_n_dev_buf` is the buffer used for storing the result matrix C on the GPU. - -## Prepare data -In the next step, the input tensors are populated. A pseudorandom number generator, an existing distribution (e.g., `FillUniformDistribution`), or user data can be used to populate the tensors. Descriptors also need to be create for each input tensor. - -Use `get_default_stride` to get the strides for A, B, and C. `get_default_stride` is a template function that calculates the default stride for a 2D array based on whether it is row-major or column-major. Template parameter determines whether the storage order is row-major (true) or column-major (false). The function takes four params `row`, `col`, `stride` and `bool_constant`. If the stride is explicitly provided (`stride != 0`), the stride is returned as-is. If the stride is not provided (`stride == 0`), the function computes the default stride. For the Row-major order (`is_row_major == true`), the stride is set to the number of columns (col). For the column-major order (`is_row_major == false`), the stride is set to the number of rows (row). This function is useful when working with dynamically allocated 2D arrays, where the user may not specify the stride explicitly. It ensures a natural default stride based on the chosen storage order. - -```cpp -// Example, API -template -auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant) { - // code -} -``` - -Where: -- `is_row_major` is a bool template parameter that determines whether the storage order is row-major (true) or column-major (false). -- `row` is the number of rows in the matrix. -- `col` is the number of columns in the matrix. -- `stride` is the current stride (the distance between consecutive elements in memory). -- `bool_constant` is a tag type that helps in differentiating behavior at compile-time. - -Next host descriptors for each of the input tensors, A, B, and C are created. Use the `f_host_tensor_descriptor` function defined below. This function takes four parameters, row, col, stride, and layout, and returns a HostTensorDescriptor based on the specified layout. - -```cpp -// Example for tensor A -ck_tile::HostTensor(f_host_tensor_descriptor(M, K, stride_As[i], a_layout))) -``` - -After creating the host_tensors, create `deviceMem` for each tensor `A`, `B`, and `C`, and then transfer the data to the device. The `get_element_space_size_in_bytes()` function is used to get the buffer size in bytes. Use `ToDevice()` to transfer data from the host to the device. The data that was previously generated (`a_m_k_tensors[i].data()`) is passed as a parameter to `ToDevice()`. - -The final step before running the GEMM operation is to retrieve the pointers to the buffers of `A`, `B`, and `C` stored on the device using `->GetDeviceBuffer()` and pack them into a shared container. For example: `gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]})`, where `gemm_descs` is `std::vector gemm_descs` ([Code](https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc#L221)). The container should include values such as: -```cpp -struct GroupedGemmHostArgs -{ - const void* a_ptr; - const void* b_ptr; - void* c_ptr; - index_t M; - index_t N; - index_t K; - index_t stride_A; - index_t stride_B; - index_t stride_C; -}; -``` -The data prepared in this way can be passed to the `invoke_gemm` function. This is a templated function that also takes three template parameters: `ALayout`, `BLayout`, and `CLayout`: -```cpp -// Example, API -template -float invoke_gemm(int n_warmup, - int n_repeat, - int group_count, - const std::vector& args) -``` -`invoke_gemm` returns the run time in milliseconds. The workspace memory required for computation is allocated. Workspace memory on the GPU refers to temporary memory buffers allocated when some operations are run. This extra space is needed to hold GEMM descriptions. The following structure can be used to allocate workspace: - -```cpp -// Example -ck_tile::DeviceMem gemm_workspace; -gemm_workspace.Realloc(GetWorkspaceSize(args)); -``` - -### Advanced Features: Preshuffle and Persistence +### Preshuffle and Persistence The grouped GEMM examples include two advanced optimization features: @@ -142,28 +10,28 @@ The grouped GEMM examples include two advanced optimization features: Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches. - **Implementation**: Available in `grouped_gemm_preshuffle.cpp` -- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration +- **Configuration**: Uses `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configuration - **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts -- **Benefits**: Improved memory efficiency and reduced data movement + #### Persistence Mode Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy. - **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm` - **Usage**: `invoke_gemm` enables persistence -- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes -Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads. +#### Multi-D Operations +Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output. -Finally the arguments are passed to group_gemm and the kernel is launched. -```cpp -// API -template -float grouped_gemm(const std::vector& gemm_descs, - const ck_tile::stream_config& s, - void* kargs_ptr) -``` -All the necessary parameters are set, the tiling is computed, the GEMM pipeline and epilogue are prepared, and the GroupedGemmKernel is launched. +- **Implementation**: Available in `grouped_gemm_multi_d.cpp` +- **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result) +- **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors +- **Data Types**: Supports fp16 +- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call +- **Build Target**: `make tile_example_grouped_gemm_multi_d -j` + +Multi-D operations supports both persistence and non-persistence modes. +Weight preshuffle supports only on non-persistence mode. ## Build ``` @@ -175,10 +43,13 @@ mkdir build && cd build make tile_example_grouped_gemm -j # The preshuffle example make tile_example_grouped_gemm_preshuffle -j +# The multi-D operations example +make tile_example_grouped_gemm_multi_d -j # The quant grouped gemm fp8 example make tile_example_quant_grouped_gemm -j ``` -This will result in an executable `build/bin/tile_example_grouped_gemm` +Each example will result in an corresponding executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`. + ## example ``` @@ -213,4 +84,4 @@ K[i] = 512 + 384 * i stride_A[i] = K[i] stride_B[i] = K[i] stride_C[i] = N[i] -``` +``` \ No newline at end of file diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 606d98d9e2..f5335c3ec0 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -70,99 +70,95 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - - return ave_time; - }; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(gemm_descs[0].k_batch == 1) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } template ( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); }; if(!splitk) { - Run(ck_tile::integral_constant{}); + return ave_time = Run(ck_tile::integral_constant{}); } else { - Run(ck_tile::integral_constant{}); + return ave_time = + Run(ck_tile::integral_constant{}); } - - return ave_time; } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 6493a542ba..10d7befc06 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -9,7 +9,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/utility/json_dump.hpp" #define CK_TILE_PIPELINE_COMPUTE_V3 1 @@ -296,7 +295,7 @@ struct PipelineTypeTraits ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; }; -using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; +using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs<>; std::pair create_args(int argc, char* argv[]) { @@ -325,7 +324,7 @@ std::pair create_args(int argc, char* argv[]) inline std::size_t get_workspace_size(const std::vector& gemm_descs) { - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>); } template diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp new file mode 100644 index 0000000000..98b0428d39 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "grouped_gemm_multi_d.hpp" + +template +float grouped_gemm_multi_d(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) +{ + + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + return ave_time; +} + +template +float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + + return ave_time; +} + +#include "run_grouped_gemm_multi_d_example.inc" + +int main(int argc, char* argv[]) +{ +#if CK_TILE_USE_WMMA + return !run_grouped_gemm_multi_d_example(argc, argv); +#else + return !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv) || + !run_grouped_gemm_multi_d_example(argc, argv); +#endif +} diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp new file mode 100644 index 0000000000..12d70eecb6 --- /dev/null +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -0,0 +1,246 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/utility/json_dump.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool TransposeC = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet + static constexpr bool Persistent = false; // currently persistent == true is not supported yet + static constexpr bool DoubleSmemBuffer = + false; // currently double smem buffer == true is not supported yet +}; + +struct GemmConfigMemory : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 8; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool Persistent = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +struct GemmConfigV3 : public GemmConfigBase +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; +struct GemmConfigV4 : public GemmConfigBase +{ + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigV3_Wmma : public GemmConfigBase +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template +struct GemmMultiDTypeConfig; + +template <> +struct GemmMultiDTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using D0DataType = ck_tile::half_t; + using D1DataType = ck_tile::half_t; + using EDataType = ck_tile::half_t; + using DsDataType = ck_tile::tuple; + using AccDataType = float; +}; + +template <> +struct GemmMultiDTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using D0DataType = ck_tile::bf16_t; + using D1DataType = ck_tile::bf16_t; + using EDataType = ck_tile::bf16_t; + using DsDataType = ck_tile::tuple; + using AccDataType = float; +}; + +// Deduce the number of D tensors from the DsDataType tuple size +// All precision configs have the same number of D tensors, so we can use any one +constexpr std::size_t NumDTensor = GemmMultiDTypeConfig::DsDataType::size(); + +using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs; + +std::pair create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("Ms", "", "M dimensions - empty by default.") + .insert("Ns", "", "N dimensions - empty by default.") + .insert("Ks", "", "K dimensions - empty by default.") + .insert("stride_As", "", "Tensor A strides - it is empty by default.") + .insert("stride_Bs", "", "Tensor B strides - it is empty by default.") + .insert("stride_Ds", "", "Tensor Ds strides - it is empty by default.") + .insert("stride_Es", "", "Tensor E strides - it is empty by default.") + .insert("a_layout", "R", "A tensor data layout - Row by default.") + .insert("b_layout", "C", "B tensor data layout - Row by default.") + .insert("ds_layout", "R", "Ds tensor data layout - Row by default.") + .insert("e_layout", "R", "E tensor data layout - Row by default.") + .insert("validate", "1", "0. No validation, 1. Validation on CPU.") + .insert("prec", "bf16", "data type. fp16/bf16") + .insert("warmup", "10", "number of iterations before benchmark the kernel.") + .insert("repeat", "100", "number of iterations to benchmark the kernel.") + .insert("group_count", "8", "group count.") + .insert("kbatch", "1", "kbatch for SplitK") + .insert("json", "0", "0: No Json, 1: Dump Results in Json format") + .insert("jsonfile", "grouped_gemm.json", "json file name to dump results"); + + bool result = arg_parser.parse(argc, argv); + return std::make_pair(result, arg_parser); +} + +inline std::size_t get_workspace_size(const std::vector& gemm_descs) +{ + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); +} + +template +float grouped_gemm_multi_d(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr); diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index 4ce55e8e72..b9d6a4a1bc 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -76,99 +76,95 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Kernel arguments not supported!"); + } - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; + } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - - return ave_time; - }; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(gemm_descs[0].k_batch == 1) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 409bb173a1..64c9dda64a 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -109,23 +109,19 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; + return ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); }; - Run(ck_tile::integral_constant{}); - - return ave_time; + return ave_time = Run(ck_tile::integral_constant{}); } #include "quant_run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc index 658a4dfa62..10d317a2c7 100644 --- a/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/quant_run_grouped_gemm_example.inc @@ -183,7 +183,7 @@ int run_grouped_gemm_example_with_layouts(int argc, if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) { std::cout << "Please check the input data. Default values will be used." << std::endl; - + // Clear existing (invalid) data before adding defaults Ms.clear(); Ns.clear(); @@ -193,7 +193,7 @@ int run_grouped_gemm_example_with_layouts(int argc, stride_Cs.clear(); stride_AQs.clear(); stride_BQs.clear(); - + for(int i = 0; i < group_count; i++) { Ms.push_back(256 + 256 * i); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 026f2bd8f6..f822c7d8a7 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -88,7 +88,7 @@ float invoke_gemm(int n_warmup, // The contents of the memory pointed to by `kargs_ptr` pointer could be // written by e.g. another kernel from earlier stage. - std::vector kargs; + std::vector> kargs; void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); const bool splitk = args[0].k_batch > 1; for(const auto& arg : args) @@ -109,7 +109,7 @@ float invoke_gemm(int n_warmup, const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, kargs.data(), - kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>), hipMemcpyHostToDevice, stream.stream_id_)); ave_time = grouped_gemm_tileloopGetDeviceBuffer(); void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); - gemm_descs.push_back( - {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + gemm_descs.push_back({p_a, + p_b, + {/*ds_ptr*/}, + p_c, + kbatch, + M, + N, + K, + stride_As[i], + stride_Bs[i], + {/*stride_Ds*/}, + stride_Cs[i]}); } float ave_time = invoke_gemm + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_gemm(int n_warmup, + int n_repeat, + int group_count, + const std::vector& args) +{ + // Workspace memory allocated to hold the gemm descriptions. + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(args)); + + float ave_time = 0; + if constexpr(!GemmConfig::Persistent) + { + ave_time = grouped_gemm_multi_d( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + gemm_workspace.GetDeviceBuffer()); + } + else + { + std::vector> kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + const bool splitk = args[0].k_batch > 1; + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr}, + {arg.b_ptr}, + arg.ds_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + arg.stride_Ds, + arg.stride_E, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR( + hipMemcpyWithStream(kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = + grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr, splitk); + } + return ave_time; +} + +template +int run_grouped_gemm_multi_d_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + const D0Layout d0_layout = D0Layout{}, + const D1Layout d1_layout = D1Layout{}, + const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + + using CDElementWise = MultiplyMultiply; + using DsLayout = ck_tile::tuple; + using DsDataType = ck_tile::tuple; + + auto valid_input_data = [&](int group_count, const auto&... args) { + return !(args.empty() || ...) && group_count == (args.size() == ...); + }; + + const int group_count = arg_parser.get_int("group_count"); + const int repeat = arg_parser.get_int("repeat"); + const int warmup = arg_parser.get_int("warmup"); + const int kbatch = arg_parser.get_int("kbatch"); + bool validate = arg_parser.get_bool("validate"); + + if(kbatch > 1 && validate && warmup + repeat > 1) + { + std::cout << "WARNING: Data validation enabled with SplitK and more than" + << "1 warmup/repeat. Disabling validation." << std::endl; + validate = false; + } + + std::vector Ms = arg_parser.get_int_vec("Ms"); + std::vector Ns = arg_parser.get_int_vec("Ns"); + std::vector Ks = arg_parser.get_int_vec("Ks"); + std::vector stride_As = arg_parser.get_int_vec("stride_As"); + std::vector stride_Bs = arg_parser.get_int_vec("stride_Bs"); + std::vector stride_D0 = arg_parser.get_int_vec("stride_Ds"); + std::vector stride_D1 = arg_parser.get_int_vec("stride_Ds"); + std::vector stride_Es = arg_parser.get_int_vec("stride_Es"); + + if(!valid_input_data( + group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_D0, stride_D1, stride_Es)) + { + std::cout << "Please check the input data. Default values will be used." << std::endl; + std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, " + "896, 1280..), stride_As (Ks), stride_Bs (Ks), stride_D0 (Ns), stride_D1 " + "(Ns), stride_Es (Ns)" + << std::endl; + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 384 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_D0.push_back(Ns[i]); + stride_D1.push_back(Ns[i]); + stride_Es.push_back(Ns[i]); + } + } + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> d0_m_n_tensors; + std::vector> d1_m_n_tensors; + std::vector> e_m_n_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + d0_m_n_tensors.reserve(group_count); + d1_m_n_tensors.reserve(group_count); + e_m_n_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> d0_m_n_dev_buf; + std::vector> d1_m_n_dev_buf; + std::vector> e_m_n_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + b_k_n_dev_buf.reserve(group_count); + d0_m_n_dev_buf.reserve(group_count); + d1_m_n_dev_buf.reserve(group_count); + e_m_n_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + + stride_D0[i] = ck_tile::get_default_stride(M, N, stride_D0[i], is_row_major(d0_layout)); + stride_D1[i] = ck_tile::get_default_stride(M, N, stride_D1[i], is_row_major(d1_layout)); + + stride_Es[i] = ck_tile::get_default_stride(M, N, stride_Es[i], is_row_major(e_layout)); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); + + d0_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_D0[i], is_row_major(d0_layout)))); + d1_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_D1[i], is_row_major(d1_layout)))); + + e_m_n_tensors.push_back(ck_tile::HostTensor( + ck_tile::host_tensor_descriptor(M, N, stride_Es[i], is_row_major(e_layout)))); + + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " d0_m_n: " << d0_m_n_tensors[i].mDesc + << " d1_m_n: " << d1_m_n_tensors[i].mDesc << " e_m_n: " << e_m_n_tensors[i].mDesc + << std::endl; + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique(a_m_k_tensors[i])); + + b_k_n_dev_buf.push_back(std::make_unique(b_k_n_tensors[i])); + + d0_m_n_dev_buf.push_back(std::make_unique(d0_m_n_tensors[i])); + d1_m_n_dev_buf.push_back(std::make_unique(d1_m_n_tensors[i])); + e_m_n_dev_buf.push_back(std::make_unique(e_m_n_tensors[i])); + + e_m_n_dev_buf[i]->SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_e = e_m_n_dev_buf[i]->GetDeviceBuffer(); + + std::array ds_ptr_buf = { + d0_m_n_dev_buf[i]->GetDeviceBuffer(), d1_m_n_dev_buf[i]->GetDeviceBuffer()}; + std::array stridesDs = {stride_D0[i], stride_D1[i]}; + + gemm_descs.push_back({p_a, + p_b, + ds_ptr_buf, + p_e, + kbatch, + M, + N, + K, + stride_As[i], + stride_Bs[i], + stridesDs, + stride_Es[i]}); + } + + float ave_time = invoke_gemm(warmup, repeat, group_count, gemm_descs); + + std::string op_name{"Grouped Gemm Multiple-D"}; + + std::size_t flop = 0, num_btype = 0; + for(int j = 0; j < group_count; ++j) + { + flop += std::size_t(2) * gemm_descs[j].M * gemm_descs[j].N * gemm_descs[j].K; + ck_tile::static_for<0, DsDataType::size(), 1>{}([&](auto i) { + num_btype += sizeof(ck_tile::remove_cvref_t>) * + gemm_descs[j].M * gemm_descs[j].N; + flop += sizeof(ck_tile::remove_cvref_t>) * + gemm_descs[j].M * gemm_descs[j].N; + }); + + num_btype += sizeof(ADataType) * gemm_descs[j].M * gemm_descs[j].K + + sizeof(BDataType) * gemm_descs[j].K * gemm_descs[j].N + + sizeof(EDataType) * gemm_descs[j].M * gemm_descs[j].N; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + std::vector> e_m_n_host_refs; + e_m_n_host_refs.reserve(group_count); + + // copy e_m_n_tensors result from device to host and initialize host tensors to zero + for(int i = 0; i < group_count; i++) + { + e_m_n_dev_buf[i]->FromDevice(e_m_n_tensors[i].data()); + } + + bool pass{true}; + if(validate) + { + for(int i = 0; i < group_count; ++i) + { + e_m_n_host_refs.push_back(ck_tile::HostTensor( + host_tensor_descriptor(Ms[i], Ns[i], stride_Es[i], is_row_major(e_layout)))); + + e_m_n_host_refs[i].SetZero(); + + ck_tile::reference_gemm_multiple_d( + a_m_k_tensors[i], + b_k_n_tensors[i], + {d0_m_n_tensors[i], d1_m_n_tensors[i]}, + e_m_n_host_refs[i]); + + const float max_accumulated_value = + *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); + + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], 1, max_accumulated_value); + + pass &= + ck_tile::check_err(e_m_n_tensors[i], + e_m_n_host_refs[i], + "Error: Incorrect results! in group [" + std::to_string(i) + "]", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + if(arg_parser.get_int("json") == 1) + { + dump_grouped_gemm_json_results(arg_parser.get_str("jsonfile"), + op_name, + group_count, + pass, + ave_time, + tflops, + gb_per_sec); + } + + return pass; +} + +template +int run_gemm_multi_d_example_prec_type( + std::string a_layout, std::string b_layout, std::string ds_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using Types = GemmMultiDTypeConfig; + + using ADataType = typename Types::ADataType; + using BDataType = typename Types::BDataType; + using D0DataType = typename Types::D0DataType; + using D1DataType = typename Types::D1DataType; + using AccDataType = typename Types::AccDataType; + using EDataType = typename Types::EDataType; + + if(a_layout == "R" && b_layout == "C" && ds_layout == "R") + { + return run_grouped_gemm_multi_d_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for provided tensors!"); + } +} + +template +int run_grouped_gemm_multi_d_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string ds_layout = arg_parser.get_str("ds_layout"); + const std::string data_type = arg_parser.get_str("prec"); + + if(data_type == "fp16") + { + return run_gemm_multi_d_example_prec_type( + a_layout, b_layout, ds_layout, argc, argv); + } + else if(data_type == "bf16") + { + return run_gemm_multi_d_example_prec_type( + a_layout, b_layout, ds_layout, argc, argv); + } + else + { + throw std::runtime_error( + "Unsupported data type configuration. Only fp16 and bf16 are supported."); + } +} diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index 280da8d333..3273fac674 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -167,38 +167,38 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c hipGetErrorString(hipMemsetAsync( args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } else { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); + return Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); - return ave_time; + return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); } template