diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 15903314f9..ccdfb0f6fb 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,8 +1,8 @@ -* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent +* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli # Documentation files -docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz -*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz -*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz -.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz +docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli +.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli # Header directory for Doxygen documentation -library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz +library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli diff --git a/CHANGELOG.md b/CHANGELOG.md index f9da2b3117..a1163f059c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,14 +7,23 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ### Added * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data +* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels. +* Added support GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW, number of instances in instance factory for NGCHW/GKYXC/NGKHW has been reduced). * Added support for GKCYX layout for grouped convolution forward (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added GEMM pipeline for microscaling (MX) data types +* Added support for FP16 2:4 structured sparsity to universal GEMM. +* Added support for Split K for grouped convolution backward data. +* Added logit soft-capping support for fMHA forward kernels. ### Optimized -None + +* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout. (#2166) +* Added Vectorize Transpose optimization for CK Tile (#2131) + ### Fixes diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c1ca789f5..4e12462a41 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -202,7 +202,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") set(CK_USE_XDL "ON") endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") - message("Enabling FP8 gemms on native architectures") + message("Enabling XDL FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() @@ -211,6 +211,11 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1 add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") + message("Enabling WMMA FP8 gemms on native architectures") + add_definitions(-DCK_USE_WMMA_FP8) + set(CK_USE_WMMA_FP8 "ON") +endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_OCP_FP8) set(CK_USE_OCP_FP8 "ON") @@ -610,6 +615,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS) PACKAGE_NAME examples ) add_subdirectory(example) + add_subdirectory(tile_engine) if(BUILD_TESTING) add_subdirectory(test) endif() diff --git a/Dockerfile b/Dockerfile index 17800d92d5..c629bd034c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ -FROM ubuntu:22.04 +FROM ubuntu:24.04 ARG DEBIAN_FRONTEND=noninteractive -ARG ROCMVERSION=6.3 +ARG ROCMVERSION=6.4 ARG compiler_version="" ARG compiler_commit="" ARG CK_SCCACHE="" @@ -9,19 +9,18 @@ ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn # Add rocm repository RUN set -xe && \ - useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins && \ apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl && \ curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN if [ "$ROCMVERSION" != "6.4" ]; then \ - sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/focal/amdgpu-install_6.3.60300-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.3.60300-1_all.deb && \ +RUN if [ "$ROCMVERSION" != "6.5" ]; then \ + sh -c "wget https://repo.radeon.com/amdgpu-install/$ROCMVERSION/ubuntu/jammy/amdgpu-install_6.4.60400-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.4.60400-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ + sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO jammy main > /etc/apt/sources.list.d/rocm.list" && \ + sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu jammy main > /etc/apt/sources.list.d/amdgpu.list'; \ fi -RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" && \ +RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu jammy main universe | tee -a /etc/apt/sources.list" && \ amdgpu-install -y --usecase=rocm --no-dkms ## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined @@ -44,17 +43,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- iputils-ping \ jq \ libelf-dev \ - libncurses5-dev \ libnuma-dev \ libpthread-stubs0-dev \ llvm-amdgpu \ mpich \ net-tools \ pkg-config \ - python \ - python3 \ - python3-dev \ - python3-pip \ + python3-full \ redis \ rocm-llvm-dev \ sshpass \ @@ -74,10 +69,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- # Remove unnecessary rocm components that take a lot of space apt-get remove -y rocblas rocfft rocsparse composablekernel-dev hipblaslt -# Update the cmake to version 3.27.5 -RUN pip install --upgrade cmake==3.27.5 && \ #Install latest ccache - git clone https://github.com/ccache/ccache.git && \ +RUN git clone https://github.com/ccache/ccache.git && \ cd ccache && mkdir build && cd build && cmake .. && make install && \ #Install ninja build tracing tools cd / && \ @@ -98,8 +91,7 @@ RUN pip install --upgrade cmake==3.27.5 && \ wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb && \ dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ # Install packages for processing the performance results - pip3 install --upgrade pip && \ - pip3 install --upgrade pytest sqlalchemy==2.0.36 pymysql pandas==2.2.3 setuptools-rust setuptools>=75 sshtunnel==0.4.0 && \ + pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ # Add render group groupadd -f render && \ # Install the new rocm-cmake version diff --git a/Dockerfile.compiler b/Dockerfile.compiler index a22103b96b..7534910681 100644 --- a/Dockerfile.compiler +++ b/Dockerfile.compiler @@ -1,4 +1,4 @@ -ARG BASE_DOCKER="rocm/composable_kernel:ck_ub22.04_rocm6.3" +ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm6.4" FROM $BASE_DOCKER ARG compiler_version="" ARG compiler_commit="" diff --git a/Jenkinsfile b/Jenkinsfile index 86cac3c485..68e0fa1246 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,11 +39,11 @@ def getBaseDockerImageName(){ } else{ def ROCM_numeric = "${params.ROCMVERSION}" as float - if ( ROCM_numeric < 6.4 ){ - img = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm${params.ROCMVERSION}" + if ( ROCM_numeric < 6.5 ){ + img = "${env.CK_DOCKERHUB}:ck_ub24.04_rocm${params.ROCMVERSION}" } else{ - img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm${params.ROCMVERSION}" + img = "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm${params.ROCMVERSION}" } } return img @@ -76,6 +76,7 @@ def check_host() { if ("${env.CK_SCCACHE}" != "null"){ def SCCACHE_SERVER="${env.CK_SCCACHE.split(':')[0]}" echo "sccache server: ${SCCACHE_SERVER}" + sh "chmod +w -R ${env.WORKSPACE}" sh '''ping -c 1 -p 6379 "${SCCACHE_SERVER}" | echo $? > tmp.txt''' def output = readFile(file: "tmp.txt") echo "tmp.txt contents: \$output" @@ -115,7 +116,7 @@ def getDockerImage(Map conf=[:]){ def retimage try { - echo "Pulling down image: ${image}" + echo "Pulling image: ${image}" retimage = docker.image("${image}") withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { retimage.pull() @@ -291,11 +292,6 @@ def cmake_build(Map conf=[:]){ setup_cmd = conf.get("setup_cmd", """${cmake_envs} cmake -G Ninja ${setup_args} -DCMAKE_CXX_FLAGS=" -O3 -ftime-trace " .. """) build_cmd = conf.get("build_cmd", "${build_envs} ninja -j${nt} ${config_targets}") } - else if (setup_args.contains("gfx908;gfx90a;gfx942")){ - //limit the number of build threads when building for multiple gfx9 targets - setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") - build_cmd = conf.get("build_cmd", "${build_envs} make -j32 ${config_targets}") - } else{ setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") build_cmd = conf.get("build_cmd", "${build_envs} make -j${nt} ${config_targets}") @@ -331,14 +327,16 @@ def cmake_build(Map conf=[:]){ } } else{ - // run unit tests - sh "make check" + // run unit tests unless building library for all targets + if (!params.BUILD_INSTANCES_ONLY){ + sh "make check" + } } } } - // Only archive from master or develop - if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) { + // Only archive from develop + if (package_build == true && env.BRANCH_NAME == "develop") { archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true } //check the node gpu architecture @@ -364,6 +362,20 @@ def cmake_build(Map conf=[:]){ echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." } } + if (params.RUN_CK_TILE_TRANSPOSE_TESTS){ + try{ + archiveArtifacts "perf_transpose_*.log" + if (arch_type == 1){ + stash includes: "perf_transpose_**_gfx90a.log", name: "perf_transpose_log_gfx90a" + } + else if (arch_type == 2){ + stash includes: "perf_transpose_**_gfx942.log", name: "perf_transpose_log_gfx942" + } + } + catch(Exception err){ + echo "could not locate the requested artifacts: ${err.getMessage()}. will skip the stashing." + } + } if (params.RUN_CK_TILE_GEMM_TESTS){ try{ archiveArtifacts "perf_tile_gemm_**.log" @@ -398,7 +410,7 @@ def buildHipClangJob(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -467,7 +479,7 @@ def Build_CK(Map conf=[:]){ def prefixpath = conf.get("prefixpath", "/opt/rocm") // Jenkins is complaining about the render group - def dockerOpts="-u root --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } @@ -517,34 +529,40 @@ def Build_CK(Map conf=[:]){ else if ( runShell('grep -n "gfx942" rocminfo.log') ) { arch_type = 2 } - else if ( runShell('grep -n "gfx1030" rocminfo.log') ) { + else if ( runShell('grep -n "gfx10" rocminfo.log') ) { arch_type = 3 } - else if ( runShell('grep -n "gfx1101" rocminfo.log') ) { + else if ( runShell('grep -n "gfx11" rocminfo.log') ) { arch_type = 4 } - else if ( runShell('grep -n "gfx1201" rocminfo.log') ) { + else if ( runShell('grep -n "gfx12" rocminfo.log') ) { arch_type = 5 } else if ( runShell('grep -n "gfx908" rocminfo.log') ) { arch_type = 6 } cmake_build(conf) - if ( !params.BUILD_LEGACY_OS && arch_type == 1 ){ + if ( params.RUN_INDUCTOR_TESTS && !params.BUILD_LEGACY_OS && arch_type == 1 ){ echo "Run inductor codegen tests" sh """ - pip install --verbose . - pytest python/test/test_gen_instances.py + python3 -m venv ${env.WORKSPACE} + . ${env.WORKSPACE}/bin/activate + python3 -m pip install pytest build setuptools setuptools_scm + python3 -m pip install . + python3 -m pytest python/test/test_gen_instances.py """ } dir("build"){ - if (params.RUN_FULL_QA && arch_type == 1 ){ + if (params.RUN_FULL_QA && arch_type == 2 ){ // build deb packages for all gfx9 targets on gfx90a system and prepare to export echo "Build ckProfiler package" sh 'make -j package' - archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' - sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash includes: "ckprofiler_0.2.0_amd64.deb", name: "ckprofiler_0.2.0_amd64.deb" + archiveArtifacts artifacts: 'composablekernel*.deb' + sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' + sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb' + sh 'mv composablekernel-examples_*.deb composablekernel-examples_1.1.0_amd64.deb' + sh 'mv composablekernel-tests_*.deb composablekernel-tests_1.1.0_amd64.deb' + stash includes: "composablekernel-**.deb", name: "packages" } } // run performance tests, stash the logs, results will be processed on the master node @@ -602,14 +620,11 @@ def Build_CK(Map conf=[:]){ stash includes: "perf_onnx_gemm_gfx12.log", name: "perf_log_gfx12" } else if ( arch_type == 6 ){ - // run standard tests on gfx908 + // run basic tests on gfx908 echo "Run performance tests" - sh "./run_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" - archiveArtifacts "perf_gemm_gfx908.log" + sh "./run_gemm_performance_tests.sh 0 CI_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME} gfx908" archiveArtifacts "perf_onnx_gemm_gfx908.log" - archiveArtifacts "perf_resnet50_N256_gfx908.log" - archiveArtifacts "perf_resnet50_N4_gfx908.log" - stash includes: "perf_**.log", name: "perf_log_gfx908" + stash includes: "perf_onnx_gemm_gfx908.log", name: "perf_log_gfx908" } } } @@ -631,10 +646,6 @@ def Build_CK(Map conf=[:]){ """ } } - // set ownership of all files and folders to jenkins after all steps completed - dir("build"){ - sh "sudo chown -R jenkins:jenkins ../*" - } } } } @@ -660,7 +671,8 @@ def Build_CK_and_Reboot(Map conf=[:]){ def process_results(Map conf=[:]){ env.HSA_ENABLE_SDMA=0 checkout scm - def image = getDockerImageName() + //use older image that has user jenkins + def image = "rocm/composable_kernel:ck_ub22.04_rocm6.3" def prefixpath = "/opt/rocm" // Jenkins is complaining about the render group @@ -673,12 +685,17 @@ def process_results(Map conf=[:]){ def retimage gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { - try { - (retimage, image) = getDockerImage(conf) + try + { + echo "Pulling image: ${image}" + retimage = docker.image("${image}") + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { + retimage.pull() + } } - catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ - echo "The job was cancelled or aborted" - throw e + catch(Exception ex) + { + error "Unable to locate image: ${image}" } } @@ -695,6 +712,15 @@ def process_results(Map conf=[:]){ echo "could not locate the FMHA performance logs: ${err.getMessage()}." } } + if (params.RUN_CK_TILE_TRANSPOSE_TESTS){ + try{ + unstash "perf_transpose_log_gfx942" + unstash "perf_transpose_log_gfx90a" + } + catch(Exception err){ + echo "could not locate the Transpose performance logs: ${err.getMessage()}." + } + } if (params.RUN_CK_TILE_GEMM_TESTS){ try{ unstash "perf_tile_gemm_log_gfx942" @@ -706,9 +732,14 @@ def process_results(Map conf=[:]){ } if (params.RUN_FULL_QA){ // unstash perf files to master - unstash "ckprofiler_0.2.0_amd64.deb" - sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" - unstash "perf_log" + unstash "packages" + sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no composablekernel-*.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" + try{ + unstash "perf_log" + } + catch(Exception err){ + echo "could not locate perf_log: ${err.getMessage()}." + } try{ unstash "perf_log_gfx11" unstash "perf_log_gfx12" @@ -745,9 +776,8 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.3;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true - 0 22 * * * % ROCMVERSION=6.3;BUILD_GFX908=true;BUILD_GFX12=false;RUN_PERFORMANCE_TESTS=false - 0 21 * * * % ROCMVERSION=6.3;hipTensor_test=true;RUN_CODEGEN_TESTS=true +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;ROCMVERSION=6.4;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true + 0 21 * * * % ROCMVERSION=6.4;hipTensor_test=true;RUN_CODEGEN_TESTS=true;BUILD_GFX908=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false @@ -772,7 +802,7 @@ pipeline { description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', - defaultValue: '6.3', + defaultValue: '6.4', description: 'Specify which ROCM version to use: 6.3 (default).') string( name: 'COMPILER_VERSION', @@ -826,6 +856,10 @@ pipeline { name: "RUN_CK_TILE_FMHA_TESTS", defaultValue: false, description: "Run the ck_tile FMHA tests (default: OFF)") + booleanParam( + name: "RUN_CK_TILE_TRANSPOSE_TESTS", + defaultValue: false, + description: "Run the ck_tile Transpose tests (default: OFF)") booleanParam( name: "RUN_CK_TILE_GEMM_TESTS", defaultValue: false, @@ -850,6 +884,10 @@ pipeline { name: "BUILD_LEGACY_OS", defaultValue: false, description: "Try building CK with legacy OS dockers: RHEL8 and SLES15 (default: OFF)") + booleanParam( + name: "RUN_INDUCTOR_TESTS", + defaultValue: true, + description: "Run inductor codegen tests (default: ON)") } environment{ dbuser = "${dbuser}" @@ -944,8 +982,8 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases_xdl && \ - ./bin/test_grouped_convnd_fwd_large_cases_xdl""" + make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1021,6 +1059,50 @@ pipeline { } } } + stage("Run CK_TILE_TRANSPOSE Tests") + { + parallel + { + stage("Run CK_TILE_TRANSPOSE Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ + make -j64 tile_example_batched_transpose && \ + cd ../ && + example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx90a """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + stage("Run CK_TILE_TRANSPOSE Tests on gfx942") + { + when { + beforeAgent true + expression { params.RUN_CK_TILE_TRANSPOSE_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ ../script/cmake-ck-dev.sh ../ gfx942 && \ + make -j64 tile_example_batched_transpose && \ + cd ../ && + example/ck_tile/35_batched_transpose/script/run_full_test.sh "CI_${params.COMPILER_VERSION}" "${env.BRANCH_NAME}" "${NODE_NAME}" gfx942 """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } stage("Run CK_TILE_GEMM Tests") { parallel @@ -1109,28 +1191,6 @@ pipeline { } } stage("Build CK for all gfx9 targets") - { - when { - beforeAgent true - expression { params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } - } - agent{ label rocmnode("gfx90a") } - environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ - -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ - -DCMAKE_CXX_FLAGS=" -O3 " """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ - cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ - -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ - } - steps{ - Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') - cleanWs() - } - } - stage("Build CK and run Tests on gfx942") { when { beforeAgent true @@ -1138,7 +1198,9 @@ pipeline { } agent{ label rocmnode("gfx942") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ + -DGPU_TARGETS="gfx942" \ + -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx942" \ @@ -1196,13 +1258,13 @@ pipeline { beforeAgent true expression { params.BUILD_INSTANCES_ONLY.toBoolean() && !params.RUN_FULL_QA.toBoolean() && !params.BUILD_LEGACY_OS.toBoolean() } } - agent{ label rocmnode("gfx90a") } + agent{ label rocmnode("gfx942") } environment{ execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ - -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j32 """ + -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1151;gfx1201" \ + -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/docs/conf.py b/docs/conf.py index e8617a09ef..fe8a1c1d79 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,6 +28,7 @@ external_toc_path = "./sphinx/_toc.yml" docs_core = ROCmDocs(left_nav_title) docs_core.run_doxygen(doxygen_root="doxygen", doxygen_path="doxygen/xml") +docs_core.enable_api_reference() docs_core.setup() external_projects_current_project = "composable_kernel" diff --git a/docs/doxygen/Doxyfile b/docs/doxygen/Doxyfile index fac9e138e1..4367aabc95 100644 --- a/docs/doxygen/Doxyfile +++ b/docs/doxygen/Doxyfile @@ -1,4 +1,4 @@ -# Doxyfile 1.8.10 +# Doxyfile 1.9.7 # This file describes the settings to be used by the documentation system # doxygen (www.doxygen.org) for a project. @@ -12,16 +12,26 @@ # For lists, items can also be appended using: # TAG += value [value, ...] # Values that contain spaces should be placed between quotes (\" \"). +# +# Note: +# +# Use doxygen to compare the used configuration file with the template +# configuration file: +# doxygen -x [configFile] +# Use doxygen to compare the used configuration file with the template +# configuration file without replacing the environment variables or CMake type +# replacement variables: +# doxygen -x_noenv [configFile] #--------------------------------------------------------------------------- # Project related configuration options #--------------------------------------------------------------------------- -# This tag specifies the encoding used for all characters in the config file -# that follow. The default is UTF-8 which is also the encoding used for all text -# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv -# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv -# for the list of possible encodings. +# This tag specifies the encoding used for all characters in the configuration +# file that follow. The default is UTF-8 which is also the encoding used for all +# text before the first occurrence of this tag. Doxygen uses libiconv (or the +# iconv built into libc) for the transcoding. See +# https://www.gnu.org/software/libiconv/ for the list of possible encodings. # The default value is: UTF-8. DOXYFILE_ENCODING = UTF-8 @@ -32,26 +42,26 @@ DOXYFILE_ENCODING = UTF-8 # title of most generated pages and in a few other places. # The default value is: My Project. -PROJECT_NAME = "ck" +PROJECT_NAME = "Composable Kernel" # The PROJECT_NUMBER tag can be used to enter a project or revision number. This # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = v3.0.1.0 +PROJECT_NUMBER = # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a # quick idea about the purpose of the project. Keep the description short. -PROJECT_BRIEF = "prototype interfaces compatible with ROCm platform and HiP" +PROJECT_BRIEF = "Prototype interfaces compatible with ROCm platform and HiP" # With the PROJECT_LOGO tag one can specify a logo or an icon that is included # in the documentation. The maximum height of the logo should not exceed 55 # pixels and the maximum width should not exceed 200 pixels. Doxygen will copy # the logo to the output directory. -PROJECT_LOGO = +PROJECT_LOGO = # The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path # into which the generated documentation will be written. If a relative path is @@ -60,16 +70,28 @@ PROJECT_LOGO = OUTPUT_DIRECTORY = . -# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- -# directories (in 2 levels) under the output directory of each output format and -# will distribute the generated files over these directories. Enabling this +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create up to 4096 +# sub-directories (in 2 levels) under the output directory of each output format +# and will distribute the generated files over these directories. Enabling this # option can be useful when feeding doxygen a huge amount of source files, where # putting all generated files in the same directory would otherwise causes -# performance problems for the file system. +# performance problems for the file system. Adapt CREATE_SUBDIRS_LEVEL to +# control the number of sub-directories. # The default value is: NO. CREATE_SUBDIRS = NO +# Controls the number of sub-directories that will be created when +# CREATE_SUBDIRS tag is set to YES. Level 0 represents 16 directories, and every +# level increment doubles the number of directories, resulting in 4096 +# directories at level 8 which is the default and also the maximum value. The +# sub-directories are organized in 2 levels, the first level always has a fixed +# number of 16 directories. +# Minimum value: 0, maximum value: 8, default value: 8. +# This tag requires that the tag CREATE_SUBDIRS is set to YES. + +CREATE_SUBDIRS_LEVEL = 8 + # If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII # characters to appear in the names of generated files. If set to NO, non-ASCII # characters will be escaped, for example _xE3_x81_x84 will be used for Unicode @@ -81,14 +103,14 @@ ALLOW_UNICODE_NAMES = NO # The OUTPUT_LANGUAGE tag is used to specify the language in which all # documentation generated by doxygen is written. Doxygen will use this # information to generate all constant output in the proper language. -# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, -# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), -# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, -# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), -# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, -# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, -# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, -# Ukrainian and Vietnamese. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Bulgarian, +# Catalan, Chinese, Chinese-Traditional, Croatian, Czech, Danish, Dutch, English +# (United States), Esperanto, Farsi (Persian), Finnish, French, German, Greek, +# Hindi, Hungarian, Indonesian, Italian, Japanese, Japanese-en (Japanese with +# English messages), Korean, Korean-en (Korean with English messages), Latvian, +# Lithuanian, Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, +# Romanian, Russian, Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, +# Swedish, Turkish, Ukrainian and Vietnamese. # The default value is: English. OUTPUT_LANGUAGE = English @@ -162,7 +184,8 @@ FULL_PATH_NAMES = YES # will be relative from the directory where doxygen is started. # This tag requires that the tag FULL_PATH_NAMES is set to YES. -STRIP_FROM_PATH = +#STRIP_FROM_PATH = +STRIP_FROM_PATH = /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/latest/ # The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the # path mentioned in the documentation of a class, which tells the reader which @@ -171,7 +194,8 @@ STRIP_FROM_PATH = # specify the list of include paths that are normally passed to the compiler # using the -I flag. -STRIP_FROM_INC_PATH = +STRIP_FROM_INC_PATH = + # If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but # less readable) file names. This can be useful is your file systems doesn't @@ -189,6 +213,16 @@ SHORT_NAMES = NO JAVADOC_AUTOBRIEF = NO +# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line +# such as +# /*************** +# as being the beginning of a Javadoc-style comment "banner". If set to NO, the +# Javadoc-style will behave just like regular comments and it will not be +# interpreted by doxygen. +# The default value is: NO. + +JAVADOC_BANNER = NO + # If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first # line (until the first dot) of a Qt-style comment as the brief description. If # set to NO, the Qt-style will behave just like regular Qt-style comments (thus @@ -209,6 +243,14 @@ QT_AUTOBRIEF = NO MULTILINE_CPP_IS_BRIEF = NO +# By default Python docstrings are displayed as preformatted text and doxygen's +# special commands cannot be used. By setting PYTHON_DOCSTRING to NO the +# doxygen's special commands can be used and the contents of the docstring +# documentation blocks is shown as doxygen documentation. +# The default value is: YES. + +PYTHON_DOCSTRING = YES + # If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the # documentation from any documented member that it re-implements. # The default value is: YES. @@ -232,20 +274,19 @@ TAB_SIZE = 4 # the documentation. An alias has the form: # name=value # For example adding -# "sideeffect=@par Side Effects:\n" +# "sideeffect=@par Side Effects:^^" # will allow you to put the command \sideeffect (or @sideeffect) in the # documentation, which will result in a user-defined paragraph with heading -# "Side Effects:". You can put \n's in the value part of an alias to insert -# newlines. +# "Side Effects:". Note that you cannot put \n's in the value part of an alias +# to insert newlines (in the resulting output). You can put ^^ in the value part +# of an alias to insert a newline as if a physical newline was in the original +# file. When you need a literal { or } or , in the value part of an alias you +# have to escape them by means of a backslash (\), this can lead to conflicts +# with the commands \{ and \} for these it is advised to use the version @{ and +# @} or use a double escape (\\{ and \\}) ALIASES = -# This tag can be used to specify a number of word-keyword mappings (TCL only). -# A mapping has the form "name=value". For example adding "class=itcl::class" -# will allow you to use the command class in the itcl::class meaning. - -TCL_SUBST = - # Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources # only. Doxygen will then generate output that is more tailored for C. For # instance, some of the names that are used will be different. The list of all @@ -274,28 +315,40 @@ OPTIMIZE_FOR_FORTRAN = NO OPTIMIZE_OUTPUT_VHDL = NO +# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice +# sources only. Doxygen will then generate output that is more tailored for that +# language. For instance, namespaces will be presented as modules, types will be +# separated into more groups, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_SLICE = NO + # Doxygen selects the parser to use depending on the extension of the files it # parses. With this tag you can assign which parser to use for a given # extension. Doxygen has a built-in mapping, but you can override or extend it # using this tag. The format is ext=language, where ext is a file extension, and -# language is one of the parsers supported by doxygen: IDL, Java, Javascript, -# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran: -# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran: -# Fortran. In the later case the parser tries to guess whether the code is fixed -# or free formatted code, this is the default for Fortran type files), VHDL. For -# instance to make doxygen treat .inc files as Fortran files (default is PHP), -# and .f files as C (default is Fortran), use: inc=Fortran f=C. +# language is one of the parsers supported by doxygen: IDL, Java, JavaScript, +# Csharp (C#), C, C++, Lex, D, PHP, md (Markdown), Objective-C, Python, Slice, +# VHDL, Fortran (fixed format Fortran: FortranFixed, free formatted Fortran: +# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser +# tries to guess whether the code is fixed or free formatted code, this is the +# default for Fortran type files). For instance to make doxygen treat .inc files +# as Fortran files (default is PHP), and .f files as C (default is Fortran), +# use: inc=Fortran f=C. # # Note: For files without extension you can use no_extension as a placeholder. # # Note that for custom extensions you also need to set FILE_PATTERNS otherwise -# the files are not read by doxygen. +# the files are not read by doxygen. When specifying no_extension you should add +# * to the FILE_PATTERNS. +# +# Note see also the list of default file extension mappings. EXTENSION_MAPPING = # If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments # according to the Markdown format, which allows for more readable -# documentation. See http://daringfireball.net/projects/markdown/ for details. +# documentation. See https://daringfireball.net/projects/markdown/ for details. # The output of markdown processing is further processed by doxygen, so you can # mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in # case of backward compatibilities issues. @@ -303,6 +356,26 @@ EXTENSION_MAPPING = MARKDOWN_SUPPORT = YES +# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up +# to that level are automatically included in the table of contents, even if +# they do not have an id attribute. +# Note: This feature currently applies only to Markdown headings. +# Minimum value: 0, maximum value: 99, default value: 5. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +TOC_INCLUDE_HEADINGS = 5 + +# The MARKDOWN_ID_STYLE tag can be used to specify the algorithm used to +# generate identifiers for the Markdown headings. Note: Every identifier is +# unique. +# Possible values are: DOXYGEN Use a fixed 'autotoc_md' string followed by a +# sequence number starting at 0. and GITHUB Use the lower case version of title +# with any whitespace replaced by '-' and punctations characters removed.. +# The default value is: DOXYGEN. +# This tag requires that the tag MARKDOWN_SUPPORT is set to YES. + +MARKDOWN_ID_STYLE = DOXYGEN + # When enabled doxygen tries to link words that correspond to documented # classes, or namespaces to their corresponding documentation. Such a link can # be prevented in individual cases by putting a % sign in front of the word or @@ -328,7 +401,7 @@ BUILTIN_STL_SUPPORT = YES CPP_CLI_SUPPORT = NO # Set the SIP_SUPPORT tag to YES if your project consists of sip (see: -# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen +# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen # will parse them like normal C++ but will assume all classes use public instead # of private inheritance when no explicit protection keyword is present. # The default value is: NO. @@ -414,6 +487,27 @@ TYPEDEF_HIDES_STRUCT = YES LOOKUP_CACHE_SIZE = 0 +# The NUM_PROC_THREADS specifies the number of threads doxygen is allowed to use +# during processing. When set to 0 doxygen will based this on the number of +# cores available in the system. You can set it explicitly to a value larger +# than 0 to get more control over the balance between CPU load and processing +# speed. At this moment only the input processing can be done using multiple +# threads. Since this is still an experimental feature the default is set to 1, +# which effectively disables parallel processing. Please report any issues you +# encounter. Generating dot graphs in parallel is controlled by the +# DOT_NUM_THREADS setting. +# Minimum value: 0, maximum value: 32, default value: 1. + +NUM_PROC_THREADS = 1 + +# If the TIMESTAMP tag is set different from NO then each generated page will +# contain the date or date and time when the page was generated. Setting this to +# NO can help when comparing the output of multiple runs. +# Possible values are: YES, NO, DATETIME and DATE. +# The default value is: NO. + +TIMESTAMP = YES + #--------------------------------------------------------------------------- # Build related configuration options #--------------------------------------------------------------------------- @@ -434,6 +528,12 @@ EXTRACT_ALL = YES EXTRACT_PRIVATE = NO +# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual +# methods of a class will be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIV_VIRTUAL = NO + # If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal # scope will be included in the documentation. # The default value is: NO. @@ -471,6 +571,13 @@ EXTRACT_LOCAL_METHODS = NO EXTRACT_ANON_NSPACES = NO +# If this flag is set to YES, the name of an unnamed parameter in a declaration +# will be determined by the corresponding definition. By default unnamed +# parameters remain unnamed in the output. +# The default value is: YES. + +RESOLVE_UNNAMED_PARAMS = YES + # If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all # undocumented members inside documented classes or files. If set to NO these # members will be included in the various overviews, but no documentation @@ -482,14 +589,15 @@ HIDE_UNDOC_MEMBERS = NO # If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all # undocumented classes that are normally visible in the class hierarchy. If set # to NO, these classes will be included in the various overviews. This option -# has no effect if EXTRACT_ALL is enabled. +# will also hide undocumented C++ concepts if enabled. This option has no effect +# if EXTRACT_ALL is enabled. # The default value is: NO. HIDE_UNDOC_CLASSES = NO # If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend -# (class|struct|union) declarations. If set to NO, these declarations will be -# included in the documentation. +# declarations. If set to NO, these declarations will be included in the +# documentation. # The default value is: NO. HIDE_FRIEND_COMPOUNDS = NO @@ -508,12 +616,20 @@ HIDE_IN_BODY_DOCS = NO INTERNAL_DOCS = NO -# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file -# names in lower-case letters. If set to YES, upper-case letters are also -# allowed. This is useful if you have classes or files whose names only differ -# in case and if your file system supports case sensitive file names. Windows -# and Mac users are advised to set this option to NO. -# The default value is: system dependent. +# With the correct setting of option CASE_SENSE_NAMES doxygen will better be +# able to match the capabilities of the underlying filesystem. In case the +# filesystem is case sensitive (i.e. it supports files in the same directory +# whose names only differ in casing), the option must be set to YES to properly +# deal with such files in case they appear in the input. For filesystems that +# are not case sensitive the option should be set to NO to properly deal with +# output files written for symbols that only differ in casing, such as for two +# classes, one named CLASS and the other named Class, and to also support +# references to files without having to specify the exact matching casing. On +# Windows (including Cygwin) and MacOS, users should typically set this option +# to NO, whereas on Linux or other Unix flavors it should typically be set to +# YES. +# Possible values are: SYSTEM, NO and YES. +# The default value is: SYSTEM. CASE_SENSE_NAMES = NO @@ -531,6 +647,12 @@ HIDE_SCOPE_NAMES = NO HIDE_COMPOUND_REFERENCE= NO +# If the SHOW_HEADERFILE tag is set to YES then the documentation for a class +# will show which file needs to be included to use the class. +# The default value is: YES. + +SHOW_HEADERFILE = YES + # If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of # the files that are included by a file in the documentation of that file. # The default value is: YES. @@ -688,7 +810,8 @@ FILE_VERSION_FILTER = # output files in an output format independent way. To create the layout file # that represents doxygen's defaults, run doxygen with the -l option. You can # optionally specify a file name after the option, if omitted DoxygenLayout.xml -# will be used as the name of the layout file. +# will be used as the name of the layout file. See also section "Changing the +# layout of pages" for information. # # Note that if you run doxygen from a directory containing a file called # DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE @@ -699,7 +822,7 @@ LAYOUT_FILE = # The CITE_BIB_FILES tag can be used to specify one or more bib files containing # the reference definitions. This must be a list of .bib files. The .bib # extension is automatically appended if omitted. This requires the bibtex tool -# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info. +# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info. # For LaTeX the style of the bibliography can be controlled using # LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the # search path. See also \cite for info how to create references. @@ -734,34 +857,81 @@ WARNINGS = YES WARN_IF_UNDOCUMENTED = YES # If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for -# potential errors in the documentation, such as not documenting some parameters -# in a documented function, or documenting parameters that don't exist or using -# markup commands wrongly. +# potential errors in the documentation, such as documenting some parameters in +# a documented function twice, or documenting parameters that don't exist or +# using markup commands wrongly. # The default value is: YES. WARN_IF_DOC_ERROR = YES +# If WARN_IF_INCOMPLETE_DOC is set to YES, doxygen will warn about incomplete +# function parameter documentation. If set to NO, doxygen will accept that some +# parameters have no documentation without warning. +# The default value is: YES. + +WARN_IF_INCOMPLETE_DOC = YES + # This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that # are documented, but have no documentation for their parameters or return -# value. If set to NO, doxygen will only warn about wrong or incomplete -# parameter documentation, but not about the absence of documentation. +# value. If set to NO, doxygen will only warn about wrong parameter +# documentation, but not about the absence of documentation. If EXTRACT_ALL is +# set to YES then this flag will automatically be disabled. See also +# WARN_IF_INCOMPLETE_DOC # The default value is: NO. WARN_NO_PARAMDOC = NO +# If WARN_IF_UNDOC_ENUM_VAL option is set to YES, doxygen will warn about +# undocumented enumeration values. If set to NO, doxygen will accept +# undocumented enumeration values. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: NO. + +WARN_IF_UNDOC_ENUM_VAL = NO + +# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when +# a warning is encountered. If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS +# then doxygen will continue running as if WARN_AS_ERROR tag is set to NO, but +# at the end of the doxygen process doxygen will return with a non-zero status. +# If the WARN_AS_ERROR tag is set to FAIL_ON_WARNINGS_PRINT then doxygen behaves +# like FAIL_ON_WARNINGS but in case no WARN_LOGFILE is defined doxygen will not +# write the warning messages in between other messages but write them at the end +# of a run, in case a WARN_LOGFILE is defined the warning messages will be +# besides being in the defined file also be shown at the end of a run, unless +# the WARN_LOGFILE is defined as - i.e. standard output (stdout) in that case +# the behavior will remain as with the setting FAIL_ON_WARNINGS. +# Possible values are: NO, YES, FAIL_ON_WARNINGS and FAIL_ON_WARNINGS_PRINT. +# The default value is: NO. + +WARN_AS_ERROR = NO + # The WARN_FORMAT tag determines the format of the warning messages that doxygen # can produce. The string should contain the $file, $line, and $text tags, which # will be replaced by the file and line number from which the warning originated # and the warning text. Optionally the format may contain $version, which will # be replaced by the version of the file (if it could be obtained via # FILE_VERSION_FILTER) +# See also: WARN_LINE_FORMAT # The default value is: $file:$line: $text. WARN_FORMAT = "$file:$line: $text" +# In the $text part of the WARN_FORMAT command it is possible that a reference +# to a more specific place is given. To make it easier to jump to this place +# (outside of doxygen) the user can define a custom "cut" / "paste" string. +# Example: +# WARN_LINE_FORMAT = "'vi $file +$line'" +# See also: WARN_FORMAT +# The default value is: at line $line of file $file. + +WARN_LINE_FORMAT = "at line $line of file $file" + # The WARN_LOGFILE tag can be used to specify a file to which warning and error # messages should be written. If left blank the output is written to standard -# error (stderr). +# error (stderr). In case the file specified cannot be opened for writing the +# warning and error messages are written to standard error. When as file - is +# specified the warning and error messages are written to standard output +# (stdout). WARN_LOGFILE = @@ -779,18 +949,29 @@ INPUT = ../../include/ck/tensor_operation/gpu/grid \ ../../include/ck/tensor_operation/gpu/block \ ../../include/ck/tensor_operation/gpu/thread \ ../../library/include/ck/library/utility \ - ../../include/ck/wrapper - + ../../include/ck/wrapper \ + ../../include/ck_tile # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses # libiconv (or the iconv built into libc) for the transcoding. See the libiconv -# documentation (see: http://www.gnu.org/software/libiconv) for the list of -# possible encodings. +# documentation (see: +# https://www.gnu.org/software/libiconv/) for the list of possible encodings. +# See also: INPUT_FILE_ENCODING # The default value is: UTF-8. INPUT_ENCODING = UTF-8 +# This tag can be used to specify the character encoding of the source files +# that doxygen parses The INPUT_FILE_ENCODING tag can be used to specify +# character encoding on a per file pattern basis. Doxygen will compare the file +# name with each pattern and apply the encoding instead of the default +# INPUT_ENCODING) if there is a match. The character encodings are a list of the +# form: pattern=encoding (like *.php=ISO-8859-1). See cfg_input_encoding +# "INPUT_ENCODING" for further information on supported encodings. + +INPUT_FILE_ENCODING = + # If the value of the INPUT tag contains directories, you can use the # FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and # *.h) to filter out the source-files in the directories. @@ -799,11 +980,15 @@ INPUT_ENCODING = UTF-8 # need to set EXTENSION_MAPPING for the extension otherwise the files are not # read by doxygen. # +# Note the list of default checked file patterns might differ from the list of +# default file extension mappings. +# # If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, # *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, -# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, -# *.m, *.markdown, *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, -# *.vhdl, *.ucf, *.qsf, *.as and *.js. +# *.hh, *.hxx, *.hpp, *.h++, *.l, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, +# *.inc, *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C +# comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f18, *.f, *.for, *.vhd, +# *.vhdl, *.ucf, *.qsf and *.ice. FILE_PATTERNS = *.c \ *.cc \ @@ -824,6 +1009,7 @@ FILE_PATTERNS = *.c \ *.hxx \ *.hpp \ *.h++ \ + *.l \ *.cs \ *.d \ *.php \ @@ -837,13 +1023,19 @@ FILE_PATTERNS = *.c \ *.mm \ *.dox \ *.py \ - *.tcl \ + *.pyw \ + *.f90 \ + *.f95 \ + *.f03 \ + *.f08 \ + *.f18 \ + *.f \ + *.for \ *.vhd \ *.vhdl \ *.ucf \ *.qsf \ - *.as \ - *.js + *.ice # The RECURSIVE tag can be used to specify whether or not subdirectories should # be searched for input files as well. @@ -880,10 +1072,7 @@ EXCLUDE_PATTERNS = # (namespaces, classes, functions, etc.) that should be excluded from the # output. The symbol name can be a fully qualified name, a word, or if the # wildcard * is used, a substring. Examples: ANamespace, AClass, -# AClass::ANamespace, ANamespace::*Test -# -# Note that the wildcards are matched against the file with absolute path, so to -# exclude all test directories use the pattern */test/* +# ANamespace::AClass, ANamespace::*Test EXCLUDE_SYMBOLS = @@ -927,6 +1116,15 @@ IMAGE_PATH = # Note that the filter must not add or remove lines; it is applied before the # code is scanned, but not when the output code is generated. If lines are added # or removed, the anchors will not be placed correctly. +# +# Note that doxygen will use the data processed and written to standard output +# for further processing, therefore nothing else, like debug statements or used +# commands (so in case of a Windows batch file always use @echo OFF), should be +# written to standard output. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. INPUT_FILTER = @@ -936,6 +1134,10 @@ INPUT_FILTER = # (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how # filters are used. If the FILTER_PATTERNS tag is empty or if none of the # patterns match the file name, INPUT_FILTER is applied. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# properly processed by doxygen. FILTER_PATTERNS = @@ -959,7 +1161,17 @@ FILTER_SOURCE_PATTERNS = # (index.html). This can be useful if you have a project on for instance GitHub # and want to reuse the introduction page also for the doxygen output. -USE_MDFILE_AS_MAINPAGE = ../README.md + +USE_MDFILE_AS_MAINPAGE = + +# The Fortran standard specifies that for fixed formatted Fortran code all +# characters from position 72 are to be considered as comment. A common +# extension is to allow longer lines before the automatic comment starts. The +# setting FORTRAN_COMMENT_AFTER will also make it possible that longer lines can +# be processed before the automatic comment starts. +# Minimum value: 7, maximum value: 10000, default value: 72. + +FORTRAN_COMMENT_AFTER = 72 #--------------------------------------------------------------------------- # Configuration options related to source browsing @@ -988,7 +1200,7 @@ INLINE_SOURCES = NO STRIP_CODE_COMMENTS = YES # If the REFERENCED_BY_RELATION tag is set to YES then for each documented -# function all documented functions referencing it will be listed. +# entity all documented functions referencing it will be listed. # The default value is: NO. REFERENCED_BY_RELATION = NO @@ -1020,12 +1232,12 @@ SOURCE_TOOLTIPS = YES # If the USE_HTAGS tag is set to YES then the references to source code will # point to the HTML generated by the htags(1) tool instead of doxygen built-in # source browser. The htags tool is part of GNU's global source tagging system -# (see http://www.gnu.org/software/global/global.html). You will need version +# (see https://www.gnu.org/software/global/global.html). You will need version # 4.8.6 or higher. # # To use it do the following: # - Install the latest version of global -# - Enable SOURCE_BROWSER and USE_HTAGS in the config file +# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file # - Make sure the INPUT points to the root of the source tree # - Run doxygen as normal # @@ -1047,25 +1259,6 @@ USE_HTAGS = NO VERBATIM_HEADERS = YES -# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the -# clang parser (see: http://clang.llvm.org/) for more accurate parsing at the -# cost of reduced performance. This can be particularly helpful with template -# rich C++ code for which doxygen's built-in parser lacks the necessary type -# information. -# Note: The availability of this option depends on whether or not doxygen was -# compiled with the --with-libclang option. -# The default value is: NO. - -CLANG_ASSISTED_PARSING = NO - -# If clang assisted parsing is enabled you can provide the compiler with command -# line options that you would normally use when invoking the compiler. Note that -# the include paths will already be set by doxygen for the files and directories -# specified with INPUT and INCLUDE_PATH. -# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. - -CLANG_OPTIONS = - #--------------------------------------------------------------------------- # Configuration options related to the alphabetical class index #--------------------------------------------------------------------------- @@ -1077,17 +1270,11 @@ CLANG_OPTIONS = ALPHABETICAL_INDEX = YES -# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in -# which the alphabetical index list will be split. -# Minimum value: 1, maximum value: 20, default value: 5. -# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. - -COLS_IN_ALPHA_INDEX = 5 - -# In case all classes in a project start with a common prefix, all classes will -# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag -# can be used to specify a prefix (or a list of prefixes) that should be ignored -# while generating the index headers. +# The IGNORE_PREFIX tag can be used to specify a prefix (or a list of prefixes) +# that should be ignored while generating the index headers. The IGNORE_PREFIX +# tag works for classes, function and member names. The entity will be placed in +# the alphabetical list under the first letter of the entity name that remains +# after removing the prefix. # This tag requires that the tag ALPHABETICAL_INDEX is set to YES. IGNORE_PREFIX = @@ -1134,7 +1321,7 @@ HTML_FILE_EXTENSION = .html # of the possible markers and block names see the documentation. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_HEADER = +HTML_HEADER = ../_doxygen/header.html # The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each # generated HTML page. If the tag is left blank doxygen will generate a standard @@ -1144,7 +1331,7 @@ HTML_HEADER = # that doxygen normally uses. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_FOOTER = +HTML_FOOTER = ../_doxygen/footer.html # The HTML_STYLESHEET tag can be used to specify a user-defined cascading style # sheet that is used by each HTML page. It can be used to fine-tune the look of @@ -1156,7 +1343,7 @@ HTML_FOOTER = # obsolete. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_STYLESHEET = +HTML_STYLESHEET = ../_doxygen/stylesheet.css # The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined # cascading style sheets that are included after the standard style sheets @@ -1166,10 +1353,15 @@ HTML_STYLESHEET = # Doxygen will copy the style sheet files to the output directory. # Note: The order of the extra style sheet files is of importance (e.g. the last # style sheet in the list overrules the setting of the previous ones in the -# list). For an example see the documentation. +# list). +# Note: Since the styling of scrollbars can currently not be overruled in +# Webkit/Chromium, the styling will be left out of the default doxygen.css if +# one or more extra stylesheets have been specified. So if scrollbar +# customization is desired it has to be added explicitly. For an example see the +# documentation. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_EXTRA_STYLESHEET = +HTML_EXTRA_STYLESHEET = ../_doxygen/extra_stylesheet.css # The HTML_EXTRA_FILES tag can be used to specify one or more extra images or # other source files which should be copied to the HTML output directory. Note @@ -1179,21 +1371,34 @@ HTML_EXTRA_STYLESHEET = # files will be copied as-is; there are no commands or markers available. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_EXTRA_FILES = +HTML_EXTRA_FILES = ../_doxygen/extra_stylesheet.css + +# The HTML_COLORSTYLE tag can be used to specify if the generated HTML output +# should be rendered with a dark or light theme. +# Possible values are: LIGHT always generate light mode output, DARK always +# generate dark mode output, AUTO_LIGHT automatically set the mode according to +# the user preference, use light mode if no preference is set (the default), +# AUTO_DARK automatically set the mode according to the user preference, use +# dark mode if no preference is set and TOGGLE allow to user to switch between +# light and dark mode via a button. +# The default value is: AUTO_LIGHT. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE = LIGHT # The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen # will adjust the colors in the style sheet and background images according to -# this color. Hue is specified as an angle on a colorwheel, see -# http://en.wikipedia.org/wiki/Hue for more information. For instance the value +# this color. Hue is specified as an angle on a color-wheel, see +# https://en.wikipedia.org/wiki/Hue for more information. For instance the value # 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 # purple, and 360 is red again. # Minimum value: 0, maximum value: 359, default value: 220. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_COLORSTYLE_HUE = 220 +HTML_COLORSTYLE_HUE = 240 # The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors -# in the HTML output. For a value of 0 the output will use grayscales only. A +# in the HTML output. For a value of 0 the output will use gray-scales only. A # value of 255 will produce the most vivid colors. # Minimum value: 0, maximum value: 255, default value: 100. # This tag requires that the tag GENERATE_HTML is set to YES. @@ -1211,14 +1416,16 @@ HTML_COLORSTYLE_SAT = 100 HTML_COLORSTYLE_GAMMA = 80 -# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML -# page will contain the date and time when the page was generated. Setting this -# to YES can help to show when doxygen was last run and thus if the -# documentation is up to date. -# The default value is: NO. +# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML +# documentation will contain a main index with vertical navigation menus that +# are dynamically created via JavaScript. If disabled, the navigation index will +# consists of multiple levels of tabs that are statically embedded in every HTML +# page. Disable this option to support browsers that do not have JavaScript, +# like the Qt help browser. +# The default value is: YES. # This tag requires that the tag GENERATE_HTML is set to YES. -HTML_TIMESTAMP = NO +HTML_DYNAMIC_MENUS = YES # If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML # documentation will contain sections that can be hidden and shown after the @@ -1243,13 +1450,14 @@ HTML_INDEX_NUM_ENTRIES = 100 # If the GENERATE_DOCSET tag is set to YES, additional index files will be # generated that can be used as input for Apple's Xcode 3 integrated development -# environment (see: http://developer.apple.com/tools/xcode/), introduced with -# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a -# Makefile in the HTML output directory. Running make will produce the docset in -# that directory and running make install will install the docset in +# environment (see: +# https://developer.apple.com/xcode/), introduced with OSX 10.5 (Leopard). To +# create a documentation set, doxygen will generate a Makefile in the HTML +# output directory. Running make will produce the docset in that directory and +# running make install will install the docset in # ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at -# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html -# for more information. +# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy +# genXcode/_index.html for more information. # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. @@ -1263,6 +1471,13 @@ GENERATE_DOCSET = NO DOCSET_FEEDNAME = "Doxygen generated docs" +# This tag determines the URL of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDURL = + # This tag specifies a string that should uniquely identify the documentation # set bundle. This should be a reverse domain-name style string, e.g. # com.mycompany.MyDocSet. Doxygen will append .docset to the name. @@ -1288,8 +1503,12 @@ DOCSET_PUBLISHER_NAME = Publisher # If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three # additional HTML index files: index.hhp, index.hhc, and index.hhk. The # index.hhp is a project file that can be read by Microsoft's HTML Help Workshop -# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on -# Windows. +# on Windows. In the beginning of 2021 Microsoft took the original page, with +# a.o. the download links, offline the HTML help workshop was already many years +# in maintenance mode). You can download the HTML help workshop from the web +# archives at Installation executable (see: +# http://web.archive.org/web/20160201063255/http://download.microsoft.com/downlo +# ad/0/A/9/0A939EF6-E31C-430F-A3DF-DFAE7960D564/htmlhelp.exe). # # The HTML Help Workshop contains a compiler that can convert all HTML output # generated by doxygen into a single compiled HTML file (.chm). Compiled HTML @@ -1319,7 +1538,7 @@ CHM_FILE = HHC_LOCATION = # The GENERATE_CHI flag controls if a separate .chi index file is generated -# (YES) or that it should be included in the master .chm file (NO). +# (YES) or that it should be included in the main .chm file (NO). # The default value is: NO. # This tag requires that the tag GENERATE_HTMLHELP is set to YES. @@ -1346,6 +1565,16 @@ BINARY_TOC = NO TOC_EXPAND = NO +# The SITEMAP_URL tag is used to specify the full URL of the place where the +# generated documentation will be placed on the server by the user during the +# deployment of the documentation. The generated sitemap is called sitemap.xml +# and placed on the directory specified by HTML_OUTPUT. In case no SITEMAP_URL +# is specified no sitemap is generated. For information about the sitemap +# protocol see https://www.sitemaps.org +# This tag requires that the tag GENERATE_HTML is set to YES. + +SITEMAP_URL = + # If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and # QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that # can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help @@ -1364,7 +1593,8 @@ QCH_FILE = # The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help # Project output. For more information please see Qt Help Project / Namespace -# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace). +# (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace). # The default value is: org.doxygen.Project. # This tag requires that the tag GENERATE_QHP is set to YES. @@ -1372,8 +1602,8 @@ QHP_NAMESPACE = org.doxygen.Project # The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt # Help Project output. For more information please see Qt Help Project / Virtual -# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual- -# folders). +# Folders (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-folders). # The default value is: doc. # This tag requires that the tag GENERATE_QHP is set to YES. @@ -1381,30 +1611,30 @@ QHP_VIRTUAL_FOLDER = doc # If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom # filter to add. For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_CUST_FILTER_NAME = # The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the # custom filter to add. For more information please see Qt Help Project / Custom -# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- -# filters). +# Filters (see: +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-filters). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_CUST_FILTER_ATTRS = # The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this # project's filter section matches. Qt Help Project / Filter Attributes (see: -# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes). +# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes). # This tag requires that the tag GENERATE_QHP is set to YES. QHP_SECT_FILTER_ATTRS = -# The QHG_LOCATION tag can be used to specify the location of Qt's -# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the -# generated .qhp file. +# The QHG_LOCATION tag can be used to specify the location (absolute path +# including file name) of Qt's qhelpgenerator. If non-empty doxygen will try to +# run qhelpgenerator on the generated .qhp file. # This tag requires that the tag GENERATE_QHP is set to YES. QHG_LOCATION = @@ -1447,16 +1677,28 @@ DISABLE_INDEX = NO # to work a browser that supports JavaScript, DHTML, CSS and frames is required # (i.e. any modern browser). Windows users are probably better off using the # HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can -# further fine-tune the look of the index. As an example, the default style -# sheet generated by doxygen has an example that shows how to put an image at -# the root of the tree instead of the PROJECT_NAME. Since the tree basically has -# the same information as the tab index, you could consider setting -# DISABLE_INDEX to YES when enabling this option. +# further fine tune the look of the index (see "Fine-tuning the output"). As an +# example, the default style sheet generated by doxygen has an example that +# shows how to put an image at the root of the tree instead of the PROJECT_NAME. +# Since the tree basically has the same information as the tab index, you could +# consider setting DISABLE_INDEX to YES when enabling this option. # The default value is: NO. # This tag requires that the tag GENERATE_HTML is set to YES. GENERATE_TREEVIEW = NO +# When both GENERATE_TREEVIEW and DISABLE_INDEX are set to YES, then the +# FULL_SIDEBAR option determines if the side bar is limited to only the treeview +# area (value NO) or if it should extend to the full height of the window (value +# YES). Setting this to YES gives a layout similar to +# https://docs.readthedocs.io with more room for contents, but less room for the +# project logo, title, and description. If either GENERATE_TREEVIEW or +# DISABLE_INDEX is set to NO, this option has no effect. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FULL_SIDEBAR = NO + # The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that # doxygen will group on one line in the generated HTML documentation. # @@ -1481,6 +1723,24 @@ TREEVIEW_WIDTH = 250 EXT_LINKS_IN_WINDOW = NO +# If the OBFUSCATE_EMAILS tag is set to YES, doxygen will obfuscate email +# addresses. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +OBFUSCATE_EMAILS = YES + +# If the HTML_FORMULA_FORMAT option is set to svg, doxygen will use the pdf2svg +# tool (see https://github.com/dawbarton/pdf2svg) or inkscape (see +# https://inkscape.org) to generate formulas as SVG images instead of PNGs for +# the HTML output. These images will generally look nicer at scaled resolutions. +# Possible values are: png (the default) and svg (looks nicer but requires the +# pdf2svg or inkscape tool). +# The default value is: png. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FORMULA_FORMAT = png + # Use this tag to change the font size of LaTeX formulas included as images in # the HTML documentation. When you change the font size after a successful # doxygen run you need to manually remove any form_*.png images from the HTML @@ -1490,19 +1750,14 @@ EXT_LINKS_IN_WINDOW = NO FORMULA_FONTSIZE = 10 -# Use the FORMULA_TRANPARENT tag to determine whether or not the images -# generated for formulas are transparent PNGs. Transparent PNGs are not -# supported properly for IE 6.0, but are supported on all modern browsers. -# -# Note that when changing this option you need to delete any form_*.png files in -# the HTML output directory before the changes have effect. -# The default value is: YES. -# This tag requires that the tag GENERATE_HTML is set to YES. +# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands +# to create new LaTeX commands to be used in formulas as building blocks. See +# the section "Including formulas" for details. -FORMULA_TRANSPARENT = YES +FORMULA_MACROFILE = # Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see -# http://www.mathjax.org) which uses client side Javascript for the rendering +# https://www.mathjax.org) which uses client side JavaScript for the rendering # instead of using pre-rendered bitmaps. Use this if you do not have LaTeX # installed or if you want to formulas look prettier in the HTML output. When # enabled you may also need to install MathJax separately and configure the path @@ -1512,11 +1767,29 @@ FORMULA_TRANSPARENT = YES USE_MATHJAX = YES +# With MATHJAX_VERSION it is possible to specify the MathJax version to be used. +# Note that the different versions of MathJax have different requirements with +# regards to the different settings, so it is possible that also other MathJax +# settings have to be changed when switching between the different MathJax +# versions. +# Possible values are: MathJax_2 and MathJax_3. +# The default value is: MathJax_2. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_VERSION = MathJax_2 + # When MathJax is enabled you can set the default output format to be used for -# the MathJax output. See the MathJax site (see: -# http://docs.mathjax.org/en/latest/output.html) for more details. +# the MathJax output. For more details about the output format see MathJax +# version 2 (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) and MathJax version 3 +# (see: +# http://docs.mathjax.org/en/latest/web/components/output.html). # Possible values are: HTML-CSS (which is slower, but has the best -# compatibility), NativeMML (i.e. MathML) and SVG. +# compatibility. This is the name for Mathjax version 2, for MathJax version 3 +# this will be translated into chtml), NativeMML (i.e. MathML. Only supported +# for NathJax 2. For MathJax version 3 chtml will be used instead.), chtml (This +# is the name for Mathjax version 3, for MathJax version 2 this will be +# translated into HTML-CSS) and SVG. # The default value is: HTML-CSS. # This tag requires that the tag USE_MATHJAX is set to YES. @@ -1529,22 +1802,29 @@ MATHJAX_FORMAT = HTML-CSS # MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax # Content Delivery Network so you can quickly see the result without installing # MathJax. However, it is strongly recommended to install a local copy of -# MathJax from http://www.mathjax.org before deployment. -# The default value is: http://cdn.mathjax.org/mathjax/latest. +# MathJax from https://www.mathjax.org before deployment. The default value is: +# - in case of MathJax version 2: https://cdn.jsdelivr.net/npm/mathjax@2 +# - in case of MathJax version 3: https://cdn.jsdelivr.net/npm/mathjax@3 # This tag requires that the tag USE_MATHJAX is set to YES. -MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest +MATHJAX_RELPATH = # The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax # extension names that should be enabled during MathJax rendering. For example +# for MathJax version 2 (see +# https://docs.mathjax.org/en/v2.7-latest/tex.html#tex-and-latex-extensions): # MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols +# For example for MathJax version 3 (see +# http://docs.mathjax.org/en/latest/input/tex/extensions/index.html): +# MATHJAX_EXTENSIONS = ams # This tag requires that the tag USE_MATHJAX is set to YES. MATHJAX_EXTENSIONS = # The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces # of code that will be used on startup of the MathJax code. See the MathJax site -# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an +# (see: +# http://docs.mathjax.org/en/v2.7-latest/output.html) for more details. For an # example see the documentation. # This tag requires that the tag USE_MATHJAX is set to YES. @@ -1572,7 +1852,7 @@ MATHJAX_CODEFILE = SEARCHENGINE = YES # When the SERVER_BASED_SEARCH tag is enabled the search engine will be -# implemented using a web server instead of a web client using Javascript. There +# implemented using a web server instead of a web client using JavaScript. There # are two flavors of web server based searching depending on the EXTERNAL_SEARCH # setting. When disabled, doxygen will generate a PHP script for searching and # an index file used by the script. When EXTERNAL_SEARCH is enabled the indexing @@ -1591,7 +1871,8 @@ SERVER_BASED_SEARCH = NO # # Doxygen ships with an example indexer (doxyindexer) and search engine # (doxysearch.cgi) which are based on the open source search engine library -# Xapian (see: http://xapian.org/). +# Xapian (see: +# https://xapian.org/). # # See the section "External Indexing and Searching" for details. # The default value is: NO. @@ -1604,8 +1885,9 @@ EXTERNAL_SEARCH = NO # # Doxygen ships with an example indexer (doxyindexer) and search engine # (doxysearch.cgi) which are based on the open source search engine library -# Xapian (see: http://xapian.org/). See the section "External Indexing and -# Searching" for details. +# Xapian (see: +# https://xapian.org/). See the section "External Indexing and Searching" for +# details. # This tag requires that the tag SEARCHENGINE is set to YES. SEARCHENGINE_URL = @@ -1656,21 +1938,35 @@ LATEX_OUTPUT = latex # The LATEX_CMD_NAME tag can be used to specify the LaTeX command name to be # invoked. # -# Note that when enabling USE_PDFLATEX this option is only used for generating -# bitmaps for formulas in the HTML output, but not in the Makefile that is -# written to the output directory. -# The default file is: latex. +# Note that when not enabling USE_PDFLATEX the default is latex when enabling +# USE_PDFLATEX the default is pdflatex and when in the later case latex is +# chosen this is overwritten by pdflatex. For specific output languages the +# default can have been set differently, this depends on the implementation of +# the output language. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_CMD_NAME = latex # The MAKEINDEX_CMD_NAME tag can be used to specify the command name to generate # index for LaTeX. +# Note: This tag is used in the Makefile / make.bat. +# See also: LATEX_MAKEINDEX_CMD for the part in the generated output file +# (.tex). # The default file is: makeindex. # This tag requires that the tag GENERATE_LATEX is set to YES. MAKEINDEX_CMD_NAME = makeindex +# The LATEX_MAKEINDEX_CMD tag can be used to specify the command name to +# generate index for LaTeX. In case there is no backslash (\) as first character +# it will be automatically added in the LaTeX code. +# Note: This tag is used in the generated output file (.tex). +# See also: MAKEINDEX_CMD_NAME for the part in the Makefile / make.bat. +# The default value is: makeindex. +# This tag requires that the tag GENERATE_LATEX is set to YES. + +LATEX_MAKEINDEX_CMD = makeindex + # If the COMPACT_LATEX tag is set to YES, doxygen generates more compact LaTeX # documents. This may be useful for small projects and may help to save some # trees in general. @@ -1700,29 +1996,31 @@ PAPER_TYPE = a4 EXTRA_PACKAGES = -# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the -# generated LaTeX document. The header should contain everything until the first -# chapter. If it is left blank doxygen will generate a standard header. See -# section "Doxygen usage" for information on how to let doxygen write the -# default header to a separate file. +# The LATEX_HEADER tag can be used to specify a user-defined LaTeX header for +# the generated LaTeX document. The header should contain everything until the +# first chapter. If it is left blank doxygen will generate a standard header. It +# is highly recommended to start with a default header using +# doxygen -w latex new_header.tex new_footer.tex new_stylesheet.sty +# and then modify the file new_header.tex. See also section "Doxygen usage" for +# information on how to generate the default header that doxygen normally uses. # -# Note: Only use a user-defined header if you know what you are doing! The -# following commands have a special meaning inside the header: $title, -# $datetime, $date, $doxygenversion, $projectname, $projectnumber, -# $projectbrief, $projectlogo. Doxygen will replace $title with the empty -# string, for the replacement values of the other commands the user is referred -# to HTML_HEADER. +# Note: Only use a user-defined header if you know what you are doing! +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. The following +# commands have a special meaning inside the header (and footer): For a +# description of the possible markers and block names see the documentation. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_HEADER = -# The LATEX_FOOTER tag can be used to specify a personal LaTeX footer for the -# generated LaTeX document. The footer should contain everything after the last -# chapter. If it is left blank doxygen will generate a standard footer. See +# The LATEX_FOOTER tag can be used to specify a user-defined LaTeX footer for +# the generated LaTeX document. The footer should contain everything after the +# last chapter. If it is left blank doxygen will generate a standard footer. See # LATEX_HEADER for more information on how to generate a default footer and what -# special commands can be used inside the footer. -# -# Note: Only use a user-defined footer if you know what you are doing! +# special commands can be used inside the footer. See also section "Doxygen +# usage" for information on how to generate the default footer that doxygen +# normally uses. Note: Only use a user-defined footer if you know what you are +# doing! # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_FOOTER = @@ -1755,18 +2053,26 @@ LATEX_EXTRA_FILES = PDF_HYPERLINKS = YES -# If the USE_PDFLATEX tag is set to YES, doxygen will use pdflatex to generate -# the PDF file directly from the LaTeX files. Set this option to YES, to get a -# higher quality PDF documentation. +# If the USE_PDFLATEX tag is set to YES, doxygen will use the engine as +# specified with LATEX_CMD_NAME to generate the PDF file directly from the LaTeX +# files. Set this option to YES, to get a higher quality PDF documentation. +# +# See also section LATEX_CMD_NAME for selecting the engine. # The default value is: YES. # This tag requires that the tag GENERATE_LATEX is set to YES. USE_PDFLATEX = YES -# If the LATEX_BATCHMODE tag is set to YES, doxygen will add the \batchmode -# command to the generated LaTeX files. This will instruct LaTeX to keep running -# if errors occur, instead of asking the user for help. This option is also used -# when generating formulas in HTML. +# The LATEX_BATCHMODE tag ignals the behavior of LaTeX in case of an error. +# Possible values are: NO same as ERROR_STOP, YES same as BATCH, BATCH In batch +# mode nothing is printed on the terminal, errors are scrolled as if is +# hit at every error; missing files that TeX tries to input or request from +# keyboard input (\read on a not open input stream) cause the job to abort, +# NON_STOP In nonstop mode the diagnostic message will appear on the terminal, +# but there is no possibility of user interaction just like in batch mode, +# SCROLL In scroll mode, TeX will stop only for missing files to input or if +# keyboard input is necessary and ERROR_STOP In errorstop mode, TeX will stop at +# each error, asking for user intervention. # The default value is: NO. # This tag requires that the tag GENERATE_LATEX is set to YES. @@ -1779,24 +2085,22 @@ LATEX_BATCHMODE = NO LATEX_HIDE_INDICES = NO -# If the LATEX_SOURCE_CODE tag is set to YES then doxygen will include source -# code with syntax highlighting in the LaTeX output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_LATEX is set to YES. - -LATEX_SOURCE_CODE = NO - # The LATEX_BIB_STYLE tag can be used to specify the style to use for the # bibliography, e.g. plainnat, or ieeetr. See -# http://en.wikipedia.org/wiki/BibTeX and \cite for more info. +# https://en.wikipedia.org/wiki/BibTeX and \cite for more info. # The default value is: plain. # This tag requires that the tag GENERATE_LATEX is set to YES. LATEX_BIB_STYLE = plain +# The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute) +# path from which the emoji images will be read. If a relative path is entered, +# it will be relative to the LATEX_OUTPUT directory. If left blank the +# LATEX_OUTPUT directory will be used. +# This tag requires that the tag GENERATE_LATEX is set to YES. + +LATEX_EMOJI_DIRECTORY = + #--------------------------------------------------------------------------- # Configuration options related to the RTF output #--------------------------------------------------------------------------- @@ -1836,9 +2140,9 @@ COMPACT_RTF = NO RTF_HYPERLINKS = NO -# Load stylesheet definitions from file. Syntax is similar to doxygen's config -# file, i.e. a series of assignments. You only have to provide replacements, -# missing definitions are set to their default value. +# Load stylesheet definitions from file. Syntax is similar to doxygen's +# configuration file, i.e. a series of assignments. You only have to provide +# replacements, missing definitions are set to their default value. # # See also section "Doxygen usage" for information on how to generate the # default style sheet that doxygen normally uses. @@ -1847,22 +2151,12 @@ RTF_HYPERLINKS = NO RTF_STYLESHEET_FILE = # Set optional variables used in the generation of an RTF document. Syntax is -# similar to doxygen's config file. A template extensions file can be generated -# using doxygen -e rtf extensionFile. +# similar to doxygen's configuration file. A template extensions file can be +# generated using doxygen -e rtf extensionFile. # This tag requires that the tag GENERATE_RTF is set to YES. RTF_EXTENSIONS_FILE = -# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code -# with syntax highlighting in the RTF output. -# -# Note that which sources are shown also depends on other settings such as -# SOURCE_BROWSER. -# The default value is: NO. -# This tag requires that the tag GENERATE_RTF is set to YES. - -RTF_SOURCE_CODE = NO - #--------------------------------------------------------------------------- # Configuration options related to the man page output #--------------------------------------------------------------------------- @@ -1934,6 +2228,13 @@ XML_OUTPUT = xml XML_PROGRAMLISTING = YES +# If the XML_NS_MEMB_FILE_SCOPE tag is set to YES, doxygen will include +# namespace members in file scope as well, matching the HTML output. +# The default value is: NO. +# This tag requires that the tag GENERATE_XML is set to YES. + +XML_NS_MEMB_FILE_SCOPE = NO + #--------------------------------------------------------------------------- # Configuration options related to the DOCBOOK output #--------------------------------------------------------------------------- @@ -1952,23 +2253,14 @@ GENERATE_DOCBOOK = NO DOCBOOK_OUTPUT = docbook -# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the -# program listings (including syntax highlighting and cross-referencing -# information) to the DOCBOOK output. Note that enabling this will significantly -# increase the size of the DOCBOOK output. -# The default value is: NO. -# This tag requires that the tag GENERATE_DOCBOOK is set to YES. - -DOCBOOK_PROGRAMLISTING = NO - #--------------------------------------------------------------------------- # Configuration options for the AutoGen Definitions output #--------------------------------------------------------------------------- # If the GENERATE_AUTOGEN_DEF tag is set to YES, doxygen will generate an -# AutoGen Definitions (see http://autogen.sf.net) file that captures the -# structure of the code including all documentation. Note that this feature is -# still experimental and incomplete at the moment. +# AutoGen Definitions (see https://autogen.sourceforge.net/) file that captures +# the structure of the code including all documentation. Note that this feature +# is still experimental and incomplete at the moment. # The default value is: NO. GENERATE_AUTOGEN_DEF = NO @@ -2047,7 +2339,8 @@ SEARCH_INCLUDES = NO # The INCLUDE_PATH tag can be used to specify one or more directories that # contain include files that are not input files but should be processed by the -# preprocessor. +# preprocessor. Note that the INCLUDE_PATH is not recursive, so the setting of +# RECURSIVE has no effect here. # This tag requires that the tag SEARCH_INCLUDES is set to YES. INCLUDE_PATH = @@ -2136,41 +2429,10 @@ EXTERNAL_GROUPS = YES EXTERNAL_PAGES = YES -# The PERL_PATH should be the absolute path and name of the perl script -# interpreter (i.e. the result of 'which perl'). -# The default file (with absolute path) is: /usr/bin/perl. - -PERL_PATH = /usr/bin/perl - #--------------------------------------------------------------------------- -# Configuration options related to the dot tool +# Configuration options related to diagram generator tools #--------------------------------------------------------------------------- -# If the CLASS_DIAGRAMS tag is set to YES, doxygen will generate a class diagram -# (in HTML and LaTeX) for classes with base or super classes. Setting the tag to -# NO turns the diagrams off. Note that this option also works with HAVE_DOT -# disabled, but it is recommended to install and use dot, since it yields more -# powerful graphs. -# The default value is: YES. - -CLASS_DIAGRAMS = NO - -# You can define message sequence charts within doxygen comments using the \msc -# command. Doxygen will then run the mscgen tool (see: -# http://www.mcternan.me.uk/mscgen/)) to produce the chart and insert it in the -# documentation. The MSCGEN_PATH tag allows you to specify the directory where -# the mscgen tool resides. If left empty the tool is assumed to be found in the -# default search path. - -MSCGEN_PATH = - -# You can include diagrams made with dia in doxygen documentation. Doxygen will -# then run dia to produce the diagram and insert it in the documentation. The -# DIA_PATH tag allows you to specify the directory where the dia binary resides. -# If left empty dia is assumed to be found in the default search path. - -DIA_PATH = - # If set to YES the inheritance and collaboration graphs will hide inheritance # and usage relations if the target is undocumented or is not a class. # The default value is: YES. @@ -2179,7 +2441,7 @@ HIDE_UNDOC_RELATIONS = YES # If you set the HAVE_DOT tag to YES then doxygen will assume the dot tool is # available from the path. This tool is part of Graphviz (see: -# http://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent +# https://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent # Bell Labs. The other options in this section have no effect if this option is # set to NO # The default value is: NO. @@ -2196,35 +2458,52 @@ HAVE_DOT = NO DOT_NUM_THREADS = 0 -# When you want a differently looking font in the dot files that doxygen -# generates you can specify the font name using DOT_FONTNAME. You need to make -# sure dot is able to find the font, which can be done by putting it in a -# standard location or by setting the DOTFONTPATH environment variable or by -# setting DOT_FONTPATH to the directory containing the font. -# The default value is: Helvetica. +# DOT_COMMON_ATTR is common attributes for nodes, edges and labels of +# subgraphs. When you want a differently looking font in the dot files that +# doxygen generates you can specify fontname, fontcolor and fontsize attributes. +# For details please see Node, +# Edge and Graph Attributes specification You need to make sure dot is able +# to find the font, which can be done by putting it in a standard location or by +# setting the DOTFONTPATH environment variable or by setting DOT_FONTPATH to the +# directory containing the font. Default graphviz fontsize is 14. +# The default value is: fontname=Helvetica,fontsize=10. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_FONTNAME = Helvetica +DOT_COMMON_ATTR = "fontname=Helvetica,fontsize=10" -# The DOT_FONTSIZE tag can be used to set the size (in points) of the font of -# dot graphs. -# Minimum value: 4, maximum value: 24, default value: 10. +# DOT_EDGE_ATTR is concatenated with DOT_COMMON_ATTR. For elegant style you can +# add 'arrowhead=open, arrowtail=open, arrowsize=0.5'. Complete documentation about +# arrows shapes. +# The default value is: labelfontname=Helvetica,labelfontsize=10. # This tag requires that the tag HAVE_DOT is set to YES. -DOT_FONTSIZE = 10 +DOT_EDGE_ATTR = "labelfontname=Helvetica,labelfontsize=10" -# By default doxygen will tell dot to use the default font as specified with -# DOT_FONTNAME. If you specify a different font using DOT_FONTNAME you can set -# the path where dot can find it using this tag. +# DOT_NODE_ATTR is concatenated with DOT_COMMON_ATTR. For view without boxes +# around nodes set 'shape=plain' or 'shape=plaintext' Shapes specification +# The default value is: shape=box,height=0.2,width=0.4. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_NODE_ATTR = "shape=box,height=0.2,width=0.4" + +# You can set the path where dot can find font specified with fontname in +# DOT_COMMON_ATTR and others dot attributes. # This tag requires that the tag HAVE_DOT is set to YES. DOT_FONTPATH = -# If the CLASS_GRAPH tag is set to YES then doxygen will generate a graph for -# each documented class showing the direct and indirect inheritance relations. -# Setting this tag to YES will force the CLASS_DIAGRAMS tag to NO. +# If the CLASS_GRAPH tag is set to YES or GRAPH or BUILTIN then doxygen will +# generate a graph for each documented class showing the direct and indirect +# inheritance relations. In case the CLASS_GRAPH tag is set to YES or GRAPH and +# HAVE_DOT is enabled as well, then dot will be used to draw the graph. In case +# the CLASS_GRAPH tag is set to YES and HAVE_DOT is disabled or if the +# CLASS_GRAPH tag is set to BUILTIN, then the built-in generator will be used. +# If the CLASS_GRAPH tag is set to TEXT the direct and indirect inheritance +# relations will be shown as texts / links. +# Possible values are: NO, YES, TEXT, GRAPH and BUILTIN. # The default value is: YES. -# This tag requires that the tag HAVE_DOT is set to YES. CLASS_GRAPH = YES @@ -2238,7 +2517,8 @@ CLASS_GRAPH = YES COLLABORATION_GRAPH = YES # If the GROUP_GRAPHS tag is set to YES then doxygen will generate a graph for -# groups, showing the direct groups dependencies. +# groups, showing the direct groups dependencies. See also the chapter Grouping +# in the manual. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. @@ -2261,10 +2541,32 @@ UML_LOOK = NO # but if the number exceeds 15, the total amount of fields shown is limited to # 10. # Minimum value: 0, maximum value: 100, default value: 10. -# This tag requires that the tag HAVE_DOT is set to YES. +# This tag requires that the tag UML_LOOK is set to YES. UML_LIMIT_NUM_FIELDS = 10 +# If the DOT_UML_DETAILS tag is set to NO, doxygen will show attributes and +# methods without types and arguments in the UML graphs. If the DOT_UML_DETAILS +# tag is set to YES, doxygen will add type and arguments for attributes and +# methods in the UML graphs. If the DOT_UML_DETAILS tag is set to NONE, doxygen +# will not generate fields with class member information in the UML graphs. The +# class diagrams will look similar to the default class diagrams but using UML +# notation for the relationships. +# Possible values are: NO, YES and NONE. +# The default value is: NO. +# This tag requires that the tag UML_LOOK is set to YES. + +DOT_UML_DETAILS = NO + +# The DOT_WRAP_THRESHOLD tag can be used to set the maximum number of characters +# to display on a single line. If the actual line length exceeds this threshold +# significantly it will wrapped across multiple lines. Some heuristics are apply +# to avoid ugly line breaks. +# Minimum value: 0, maximum value: 1000, default value: 17. +# This tag requires that the tag HAVE_DOT is set to YES. + +DOT_WRAP_THRESHOLD = 17 + # If the TEMPLATE_RELATIONS tag is set to YES then the inheritance and # collaboration graphs will show the relations between templates and their # instances. @@ -2331,10 +2633,17 @@ GRAPHICAL_HIERARCHY = YES DIRECTORY_GRAPH = YES +# The DIR_GRAPH_MAX_DEPTH tag can be used to limit the maximum number of levels +# of child directories generated in directory dependency graphs by dot. +# Minimum value: 1, maximum value: 25, default value: 1. +# This tag requires that the tag DIRECTORY_GRAPH is set to YES. + +DIR_GRAPH_MAX_DEPTH = 1 + # The DOT_IMAGE_FORMAT tag can be used to set the image format of the images # generated by dot. For an explanation of the image formats see the section # output formats in the documentation of the dot tool (Graphviz (see: -# http://www.graphviz.org/)). +# https://www.graphviz.org/)). # Note: If you choose svg you need to set HTML_FILE_EXTENSION to xhtml in order # to make the SVG files visible in IE 9+ (other browsers do not have this # requirement). @@ -2371,11 +2680,12 @@ DOT_PATH = DOTFILE_DIRS = -# The MSCFILE_DIRS tag can be used to specify one or more directories that -# contain msc files that are included in the documentation (see the \mscfile -# command). +# You can include diagrams made with dia in doxygen documentation. Doxygen will +# then run dia to produce the diagram and insert it in the documentation. The +# DIA_PATH tag allows you to specify the directory where the dia binary resides. +# If left empty dia is assumed to be found in the default search path. -MSCFILE_DIRS = +DIA_PATH = # The DIAFILE_DIRS tag can be used to specify one or more directories that # contain dia files that are included in the documentation (see the \diafile @@ -2384,13 +2694,18 @@ MSCFILE_DIRS = DIAFILE_DIRS = # When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the -# path where java can find the plantuml.jar file. If left blank, it is assumed -# PlantUML is not used or called during a preprocessing step. Doxygen will -# generate a warning when it encounters a \startuml command in this case and -# will not generate output for the diagram. +# path where java can find the plantuml.jar file or to the filename of jar file +# to be used. If left blank, it is assumed PlantUML is not used or called during +# a preprocessing step. Doxygen will generate a warning when it encounters a +# \startuml command in this case and will not generate output for the diagram. PLANTUML_JAR_PATH = +# When using plantuml, the PLANTUML_CFG_FILE tag can be used to specify a +# configuration file for plantuml. + +PLANTUML_CFG_FILE = + # When using plantuml, the specified paths are searched for files specified by # the !include statement in a plantuml block. @@ -2420,18 +2735,6 @@ DOT_GRAPH_MAX_NODES = 50 MAX_DOT_GRAPH_DEPTH = 0 -# Set the DOT_TRANSPARENT tag to YES to generate images with a transparent -# background. This is disabled by default, because dot on Windows does not seem -# to support this out of the box. -# -# Warning: Depending on the platform used, enabling this option may lead to -# badly anti-aliased labels on the edges of a graph (i.e. they become hard to -# read). -# The default value is: NO. -# This tag requires that the tag HAVE_DOT is set to YES. - -DOT_TRANSPARENT = NO - # Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output # files in one run (i.e. multiple -o and -T options on the command line). This # makes dot run faster, but since only newer versions of dot (>1.8.10) support @@ -2444,14 +2747,34 @@ DOT_MULTI_TARGETS = NO # If the GENERATE_LEGEND tag is set to YES doxygen will generate a legend page # explaining the meaning of the various boxes and arrows in the dot generated # graphs. +# Note: This tag requires that UML_LOOK isn't set, i.e. the doxygen internal +# graphical representation for inheritance and collaboration diagrams is used. # The default value is: YES. # This tag requires that the tag HAVE_DOT is set to YES. GENERATE_LEGEND = YES -# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate dot +# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate # files that are used to generate the various graphs. +# +# Note: This setting is not only used for dot files but also for msc temporary +# files. # The default value is: YES. -# This tag requires that the tag HAVE_DOT is set to YES. DOT_CLEANUP = YES + +# You can define message sequence charts within doxygen comments using the \msc +# command. If the MSCGEN_TOOL tag is left empty (the default), then doxygen will +# use a built-in version of mscgen tool to produce the charts. Alternatively, +# the MSCGEN_TOOL tag can also specify the name an external tool. For instance, +# specifying prog as the value, doxygen will call the tool as prog -T +# -o . The external tool should support +# output file formats "png", "eps", "svg", and "ismap". + +MSCGEN_TOOL = + +# The MSCFILE_DIRS tag can be used to specify one or more directories that +# contain msc files that are included in the documentation (see the \mscfile +# command). + +MSCFILE_DIRS = diff --git a/docs/index.rst b/docs/index.rst index 15a9321d43..4cc26a1d3e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,7 +10,7 @@ Composable Kernel User Guide The Composable Kernel library provides a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs and CPUs, through general purpose kernel languages such as `HIP C++ `_. -The Composable Kernel repository is located at `https://github.com/ROCm/composable-kernel `_. +The Composable Kernel repository is located at `https://github.com/ROCm/composable_kernel `_. .. grid:: 2 :gutter: 3 @@ -35,9 +35,9 @@ The Composable Kernel repository is located at `https://github.com/ROCm/composab * :doc:`Composable Kernel supported scalar types <./reference/Composable_Kernel_supported_scalar_types>` * :doc:`Composable Kernel custom types <./reference/Composable_Kernel_custom_types>` * :doc:`Composable Kernel vector utilities <./reference/Composable_Kernel_vector_utilities>` - * :ref:`api-reference` - * :ref:`wrapper` - + * :ref:`wrapper` + * :doc:`Composable Kernel complete class list <./doxygen/html/annotated>` + To contribute to the documentation refer to `Contributing to ROCm `_. You can find licensing information on the `Licensing `_ page. diff --git a/docs/reference/Composable-Kernel-API-reference.rst b/docs/reference/Composable-Kernel-API-reference.rst deleted file mode 100644 index b6ee9f7790..0000000000 --- a/docs/reference/Composable-Kernel-API-reference.rst +++ /dev/null @@ -1,42 +0,0 @@ -.. meta:: - :description: Composable Kernel documentation and API reference library - :keywords: composable kernel, CK, ROCm, API, documentation - -.. _api-reference: - -******************************************************************** -Composable Kernel API reference guide -******************************************************************** - -This document contains details of the APIs for the Composable Kernel library and introduces some of the key design principles that are used to write new classes that extend the functionality of the Composable Kernel library. - -================= -DeviceMem -================= - -.. doxygenstruct:: DeviceMem - -============================= -Kernels For Flashattention -============================= - -The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This section lists -the classes that are used in the CK GPU implementation of Flashattention. - -**Gridwise classes** - -.. doxygenstruct:: ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle - -**Blockwise classes** - -.. doxygenstruct:: ck::ThreadGroupTensorSliceTransfer_v4r1 - -.. doxygenstruct:: ck::BlockwiseGemmXdlops_v2 - -.. doxygenstruct:: ck::BlockwiseSoftmax - -**Threadwise classes** - -.. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic - -.. bibliography:: diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index df98998224..2ef3383d84 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -32,10 +32,10 @@ subtrees: title: Composable Kernel custom types - file: reference/Composable_Kernel_vector_utilities.rst title: Composable Kernel vector utilities - - file: reference/Composable-Kernel-API-reference.rst - title: Composable Kernel API reference - file: reference/Composable-Kernel-wrapper.rst - title: Composable Kernel Wrapper + title: Composable Kernel wrapper + - file: doxygen/html/annotated.rst + title: Composable Kernel class list - caption: About entries: diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 2fcf3b3935..6c48b2de09 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.18.1 +rocm-docs-core[api_reference]==1.18.4 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 12572d400e..62c3ea8ff8 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -6,68 +6,79 @@ # accessible-pygments==0.0.5 # via pydata-sphinx-theme -alabaster==0.7.16 +alabaster==1.0.0 # via sphinx asttokens==3.0.0 # via stack-data -attrs==24.3.0 +attrs==25.3.0 # via # jsonschema # jupyter-cache # referencing -babel==2.15.0 +babel==2.17.0 # via # pydata-sphinx-theme # sphinx -beautifulsoup4==4.12.3 +beautifulsoup4==4.13.4 # via pydata-sphinx-theme -breathe==4.35.0 +breathe==4.36.0 # via rocm-docs-core -certifi==2024.7.4 +certifi==2025.1.31 # via requests -cffi==1.16.0 +cffi==1.17.1 # via # cryptography # pynacl -charset-normalizer==3.3.2 +charset-normalizer==3.4.1 # via requests -click==8.1.7 +click==8.1.8 # via + # click-log + # doxysphinx # jupyter-cache # sphinx-external-toc +click-log==0.4.0 + # via doxysphinx comm==0.2.2 # via ipykernel -cryptography==43.0.0 +contourpy==1.3.2 + # via matplotlib +cryptography==44.0.2 # via pyjwt -debugpy==1.8.12 +cycler==0.12.1 + # via matplotlib +debugpy==1.8.14 # via ipykernel -decorator==5.1.1 +decorator==5.2.1 # via ipython -deprecated==1.2.14 +deprecated==1.2.18 # via pygithub docutils==0.21.2 # via - # breathe # myst-parser # pybtex-docutils # pydata-sphinx-theme # sphinx # sphinxcontrib-bibtex +doxysphinx==3.3.12 + # via rocm-docs-core exceptiongroup==1.2.2 # via ipython -executing==2.1.0 +executing==2.2.0 # via stack-data -fastjsonschema==2.20.0 +fastjsonschema==2.21.1 # via # nbformat # rocm-docs-core -gitdb==4.0.11 +fonttools==4.57.0 + # via matplotlib +gitdb==4.0.12 # via gitpython -gitpython==3.1.43 +gitpython==3.1.44 # via rocm-docs-core -greenlet==3.1.1 +greenlet==3.2.1 # via sqlalchemy -idna==3.7 +idna==3.10 # via requests imagesize==1.4.1 # via sphinx @@ -77,13 +88,13 @@ importlib-metadata==8.6.1 # myst-nb ipykernel==6.29.5 # via myst-nb -ipython==8.31.0 +ipython==8.35.0 # via # ipykernel # myst-nb jedi==0.19.2 # via ipython -jinja2==3.1.4 +jinja2==3.1.6 # via # myst-parser # sphinx @@ -103,25 +114,35 @@ jupyter-core==5.7.2 # jupyter-client # nbclient # nbformat +kiwisolver==1.4.8 + # via matplotlib latexcodec==3.0.0 # via pybtex +libsass==0.22.0 + # via doxysphinx +lxml==5.2.1 + # via doxysphinx markdown-it-py==3.0.0 # via # mdit-py-plugins # myst-parser -markupsafe==2.1.5 +markupsafe==3.0.2 # via jinja2 +matplotlib==3.10.1 + # via doxysphinx matplotlib-inline==0.1.7 # via # ipykernel # ipython -mdit-py-plugins==0.4.1 +mdit-py-plugins==0.4.2 # via myst-parser mdurl==0.1.2 # via markdown-it-py -myst-nb==1.1.2 +mpire==2.10.2 + # via doxysphinx +myst-nb==1.2.0 # via rocm-docs-core -myst-parser==3.0.1 +myst-parser==4.0.1 # via myst-nb nbclient==0.10.2 # via @@ -134,20 +155,28 @@ nbformat==5.10.4 # nbclient nest-asyncio==1.6.0 # via ipykernel -packaging==24.1 +numpy==1.26.4 + # via + # contourpy + # doxysphinx + # matplotlib +packaging==25.0 # via # ipykernel + # matplotlib # pydata-sphinx-theme # sphinx parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -platformdirs==4.3.6 +pillow==11.2.1 + # via matplotlib +platformdirs==4.3.7 # via jupyter-core -prompt-toolkit==3.0.50 +prompt-toolkit==3.0.51 # via ipython -psutil==6.1.1 +psutil==7.0.0 # via ipykernel ptyprocess==0.7.0 # via pexpect @@ -165,21 +194,30 @@ pydata-sphinx-theme==0.15.4 # via # rocm-docs-core # sphinx-book-theme -pygithub==2.3.0 +pygithub==2.6.1 # via rocm-docs-core -pygments==2.18.0 +pygments==2.19.1 # via # accessible-pygments # ipython + # mpire # pydata-sphinx-theme # sphinx -pyjwt[crypto]==2.8.0 +pyjson5==1.6.8 + # via doxysphinx +pyjwt[crypto]==2.10.1 # via pygithub pynacl==1.5.0 # via pygithub +pyparsing==3.2.3 + # via + # doxysphinx + # matplotlib python-dateutil==2.9.0.post0 - # via jupyter-client -pyyaml==6.0.1 + # via + # jupyter-client + # matplotlib +pyyaml==6.0.2 # via # jupyter-cache # myst-nb @@ -187,11 +225,11 @@ pyyaml==6.0.1 # pybtex # rocm-docs-core # sphinx-external-toc -pyzmq==26.2.0 +pyzmq==26.4.0 # via # ipykernel # jupyter-client -referencing==0.36.1 +referencing==0.36.2 # via # jsonschema # jsonschema-specifications @@ -199,23 +237,23 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.18.1 +rocm-docs-core[api-reference]==1.18.4 # via -r requirements.in -rpds-py==0.22.3 +rpds-py==0.24.0 # via # jsonschema # referencing -six==1.16.0 +six==1.17.0 # via # pybtex # python-dateutil -smmap==5.0.1 +smmap==5.0.2 # via gitdb snowballstemmer==2.2.0 # via sphinx -soupsieve==2.5 +soupsieve==2.7 # via beautifulsoup4 -sphinx==7.4.7 +sphinx==8.1.3 # via # breathe # myst-nb @@ -228,15 +266,15 @@ sphinx==7.4.7 # sphinx-external-toc # sphinx-notfound-page # sphinxcontrib-bibtex -sphinx-book-theme==1.1.3 +sphinx-book-theme==1.1.4 # via rocm-docs-core sphinx-copybutton==0.5.2 # via rocm-docs-core -sphinx-design==0.6.0 +sphinx-design==0.6.1 # via rocm-docs-core sphinx-external-toc==1.0.1 # via rocm-docs-core -sphinx-notfound-page==1.0.3 +sphinx-notfound-page==1.1.0 # via rocm-docs-core sphinxcontrib-applehelp==2.0.0 # via sphinx @@ -252,18 +290,20 @@ sphinxcontrib-qthelp==2.0.0 # via sphinx sphinxcontrib-serializinghtml==2.0.0 # via sphinx -sqlalchemy==2.0.37 +sqlalchemy==2.0.40 # via jupyter-cache stack-data==0.6.3 # via ipython tabulate==0.9.0 # via jupyter-cache -tomli==2.0.1 +tomli==2.2.1 # via sphinx tornado==6.4.2 # via # ipykernel # jupyter-client +tqdm==4.67.1 + # via mpire traitlets==5.14.3 # via # comm @@ -274,21 +314,22 @@ traitlets==5.14.3 # matplotlib-inline # nbclient # nbformat -typing-extensions==4.12.2 +typing-extensions==4.13.2 # via + # beautifulsoup4 # ipython # myst-nb # pydata-sphinx-theme # pygithub # referencing # sqlalchemy -urllib3==2.2.2 +urllib3==2.4.0 # via # pygithub # requests wcwidth==0.2.13 # via prompt-toolkit -wrapt==1.16.0 +wrapt==1.17.2 # via deprecated zipp==3.21.0 # via importlib-metadata diff --git a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp index 7c232f1bcf..7178ad46b9 100644 --- a/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_bf16_pk_i4_v3.cpp @@ -133,7 +133,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // weight permute @@ -192,14 +192,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp index 61c5a32d5d..e16f184a20 100644 --- a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3.cpp @@ -134,7 +134,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // weight permute @@ -242,14 +242,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp index 468dd699a1..f83d479713 100644 --- a/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp +++ b/example/01_gemm/gemm_xdl_fp16_pk_i4_v3_b_scale.cpp @@ -161,7 +161,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); @@ -274,14 +274,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp index 80f7e95d30..266a1e9d3e 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp @@ -152,7 +152,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize() / + 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // do GEMM @@ -261,14 +262,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp index 7b72461dd9..0575314dff 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_v3.cpp @@ -132,7 +132,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // weight permute @@ -240,14 +240,20 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) b_element_op, c_element_op); - if(!gemm.IsSupportedArgument(argument) || ck::get_device_name() != "gfx942" || - ck::get_device_name() != "gfx950") + if(!gemm.IsSupportedArgument(argument)) { std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; return true; } + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + + return true; + } + bool pass = true; if(config.do_verification) { diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 86b3182a52..7186c22233 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -21,6 +21,7 @@ struct ExecutionConfig final bool do_verification = true; int init_method = 1; bool time_kernel = false; + bool async_hargs = false; }; bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) @@ -190,10 +191,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co gemm_workspace.Realloc(workspace_size); gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer()); } - if(hargs_size > 0) + if(config.async_hargs && hargs_size > 0) { hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size)); - gemm.SetHostKernelArgs(&argument, gemm_hargs); + gemm.SetHostKernelArgsPointer(&argument, gemm_hargs); } if(!gemm.IsSupportedArgument(argument)) @@ -203,16 +204,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co "not support this GEMM problem"); } - hipStream_t stream0 = nullptr; - hip_check_error(hipStreamCreate(&stream0)); + if(!config.async_hargs) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + } + else + { + hipStream_t stream0 = nullptr; + hip_check_error(hipStreamCreate(&stream0)); - hipEvent_t event0 = nullptr; - hip_check_error(hipEventCreate(&event0)); + hipEvent_t event0 = nullptr; + hip_check_error(hipEventCreate(&event0)); - invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); + invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); - hip_check_error(hipEventSynchronize(event0)); - hip_check_error(hipStreamSynchronize(stream0)); + hip_check_error(hipEventSynchronize(event0)); + hip_check_error(hipStreamSynchronize(stream0)); + } bool pass = true; if(config.do_verification) @@ -280,18 +288,25 @@ bool run_grouped_gemm_example(int argc, char* argv[]) problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 4) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); config.time_kernel = std::stoi(argv[3]); } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.async_hargs = std::stoi(argv[4]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: async hargs (0=n0, 1=yes)\n"); exit(0); } diff --git a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc index 8c4913dbcc..3582bc5e33 100644 --- a/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc +++ b/example/24_batched_gemm/run_batched_gemm_example_fp16int4_b_scale.inc @@ -212,7 +212,8 @@ bool run_batched_gemm(const ProblemSize& problem_size, const ExecutionConfig& co std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem b_g_k_n_device_buf(sizeof(BDataType) * b_g_k_n_permute.mDesc.GetElementSpaceSize() / + 2); DeviceMem b1_g_scale_device_buf(sizeof(BScaleDataType) * b1_g_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_g_m_n_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpaceSize()); diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 95fd8bace8..8d51d43c65 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1,17 +1,27 @@ add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) +add_example_executable(example_gemm_multiply_multiply_xdl_fp16_bpreshuffle gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) -list(APPEND gpu_list gfx942) +list(APPEND gpu_list gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp) add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp) + if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132) + set(EXAMPLE_COMPILE_OPTIONS) + check_cxx_compiler_flag("-mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1" HAS_MAX_ILP_SCHEDULING_STRATEGY) + if(HAS_MAX_ILP_SCHEDULING_STRATEGY) + list(APPEND EXAMPLE_COMPILE_OPTIONS -mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1) + endif() + target_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + target_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) + endif() set(target 1) endif() endforeach() diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp new file mode 100644 index 0000000000..69803c7eeb --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16_bpreshuffle.cpp @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = F16; +using B0DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using A0Layout = Row; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(F16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()(BF16& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +void preShuffleBuffer(const F16* src, F16* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(F16); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MultiplyMultiply; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + // clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RCR + // kernel 1: 256->32x128x128 + < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, + 32, 128, 128, + 8, 8, + 32, 32, + 1, 1, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F16>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD = 0; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + KBatch = std::stoi(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, KBatch\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + d0_m_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_m_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + d0_device_buf.ToDevice(d0_m_n.mData.data()); + d1_device_buf.ToDevice(d1_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + constexpr auto I0 = ck::Number<0>{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer(), + d1_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{I0, I0}, + StrideE, + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 50, false, 1}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + invoker.Run(argument, StreamConfig{nullptr, false}); + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + ? 0 + : 1; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index e4e6a4f1a7..9f758d5fc5 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -9,7 +9,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" @@ -142,12 +141,12 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 128, 16, 16, - 32, 32, - 2, 2, + 16, 16, + 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; + 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp index 66825edcf9..3b31460953 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -25,7 +25,6 @@ template using S = ck::Sequence; using F16 = ck::half_t; -// using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using F32 = float; @@ -36,17 +35,19 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = EDataType; using D0DataType = F32; using D1DataType = F32; -using DsDataType = ck::Tuple; +using D2DataType = F32; +using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -59,35 +60,66 @@ struct MulABScale __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { - e = ck::type_convert(c * d1 * d0); + (void)d0; + (void)d1; + e = ck::type_convert(c); } -}; - -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu -{ - template - __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const { - // act - float x0 = 0; - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); - e = ck::type_convert(x0); + (void)d0; + (void)d1; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const EDataType& d0, const EDataType& d1) const + { + (void)d0; + (void)d1; + e = ck::type_convert(c); } }; -// using DsLayout = DsLayoutGate; -// using DsDataType = DsDataTypeGate; -using CDEElementOp = MulABScale; +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } +}; -// using CDEElementOp = MulABScaleSiluMulGate; +using CDEElementOp = MulABScaleExpertWeight; void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) { @@ -126,20 +158,21 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 4; static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t NPerBlock = 64; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = true; +static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); static constexpr ck::index_t EVec = 16 / sizeof(EDataType); static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; -// using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm +static constexpr ck::index_t ActOP = 1; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = false; +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, @@ -157,8 +190,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + 2, 2, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, true, int32_t, A0DataType>; // clang-format on @@ -170,15 +203,13 @@ int main(int argc, char* argv[]) // GEMM shape ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t K = 6144; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 8; - ck::index_t valid_tile_num = 8; - ck::index_t tokens = 128; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; + ck::index_t tokens = 64; ck::index_t topk = 2; - // ck::index_t tokens = batch * topk; - if(argc == 1) { // use default case @@ -224,28 +255,22 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{1, 0}; + constexpr auto StrideDs = std::array{1, 1, 1}; ck::index_t KBatch = 1; - // const ck::index_t experts = 8; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); - // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + max_token_id.mData = {valid_size}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; } int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; - // sorted_token_ids.mData[0] = 0; + for(int i = 0; i < sorted_size; i++) { int tile_off = i % MPerBlock; @@ -259,48 +284,54 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; switch(init_method) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); - d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0, 1}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; case 3: - a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); - d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); @@ -310,16 +341,16 @@ int main(int argc, char* argv[]) DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - // a0_t_k.savetxt("a.txt"); - // d0_t_n.savetxt("d0_t_n.txt", "int"); - // d1_e_n.savetxt("d1_e_n.txt", "int"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); a0_device_buf.ToDevice(a0_t_k.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -330,7 +361,8 @@ int main(int argc, char* argv[]) int NPerXdl = device_op.GetPreShuffleParameters(); - preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, NPerXdl); + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); b0_device_buf.ToDevice(b0_preshuffled.mData.data()); @@ -342,7 +374,8 @@ int main(int argc, char* argv[]) a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer()}, + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), tokens, topk, @@ -368,9 +401,9 @@ int main(int argc, char* argv[]) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) * K * N * experts + + sizeof(B0DataType) * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -392,10 +425,13 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + PassThrough, + ActOP, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -404,8 +440,11 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, + d2_e_n, PassThrough{}, PassThrough{}, PassThrough{}); @@ -428,15 +467,15 @@ int main(int argc, char* argv[]) cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), - d1_e_n(e, n)); + d1_e_n(e, n), + d2_e_n(e, n)); } } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - // e_t_n_device_result.savetxt("out.txt"); - // e_t_n_host_result.savetxt("ref.txt"); + return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) ? 0 : 1; } diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 1102ce1054..9e80a2ca35 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -36,17 +36,19 @@ using A0DataType = F8; using B0DataType = I4; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; -using DsDataType = ck::Tuple; +using D2DataType = F32; +using DsDataType = ck::Tuple; using A0Layout = Row; using B0Layout = Col; using ELayout = Row; using D0Layout = Row; using D1Layout = Col; -using DsLayout = ck::Tuple; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; // for gate, a_scale, b_scale struct MulABScale @@ -55,43 +57,74 @@ struct MulABScale __host__ __device__ constexpr void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1) const + { + (void)d0; + (void)d1; +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c); +#else + e = ck::type_convert(c); +#endif + } template <> __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1) const { + (void)d0; + (void)d1; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d0 * 16); + e = ck::type_convert(c); #else - e = ck::type_convert(c * d1 * d0); + e = ck::type_convert(c); #endif } }; -// for gate, a_scale, b_scale, fuse silu, -struct MulABScaleSilu +struct MulABScaleExpertWeight { - template + template __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const; - + operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const; + // for real kernel use template <> - __host__ __device__ constexpr void operator()(EDataType& e, - const float& c, - const float& d0, - const float& d1) const + __host__ __device__ constexpr void operator()( + EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { - // act - float x0 = 0; -#if CK_USE_PK4_LAYOUT_SHUFFLE - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16); -#else - ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); -#endif - e = ck::type_convert(x0); + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + // for reference cpu + template <> + __host__ __device__ constexpr void operator()( + float& e, const float& c, const float& d0, const float& d1, const float& d2) const + { + // for reference cpu + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } }; -using CDEElementOp = MulABScale; +static constexpr bool MulRoutedWeight = true; + +using CDEElementOp = MulABScaleExpertWeight; // combine MulRoutedWeight = true + +// using CDEElementOp = MulABScale; // combine MulRoutedWeight = true #if 1 void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl) @@ -132,53 +165,24 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -#if 0 -static constexpr ck::index_t MPerBlock = 64; -static constexpr ck::index_t MXDLPerWave = 1; -static constexpr ck::index_t NXDLPerWave = 2; -static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; -static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType); -static constexpr ck::index_t Nswizzle = false; -static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); -static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); -static constexpr ck::index_t EVec = 16 / sizeof(EDataType); -static constexpr ck::index_t D0Vec = 1; -static constexpr ck::index_t D1Vec = 1; -// clang-format off -using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< - Row, Col, DsLayout, ELayout, - A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, - AElementOp, BElementOp, CDEElementOp, GemmSpec, - BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock, - AK1, BK1, - MNPerXDL, MNPerXDL, - MXDLPerWave, NXDLPerWave, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - MXDLPerWave, 1, S<1, 32, 1, 8>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; -// clang-format on -#else static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t Act_OP = 1; // 0: gelu_and_mul, 1: silu_and_mul // clang-format off using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - 256, MPerBlock, 128, 128, + 256, MPerBlock, 64, 128, 16, 32, - 32, 32, - 4, 1, + 16, 16, + 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 1, 1, S<1, 32, 1, 8>, S<8, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; + 2, 1, S<1, 32, 1, 8>, S<8, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Act_OP, Nswizzle, true, MulRoutedWeight, true, ck::index_t, A0DataType>; // clang-format on -#endif int main(int argc, char* argv[]) { @@ -186,13 +190,10 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape - ck::index_t N = 4096 * 2; - ck::index_t K = 6144; + ck::index_t N = 14336; + ck::index_t K = 4096; ck::index_t experts = 8; ck::index_t sorted_tile_num = 16; ck::index_t valid_tile_num = 13; @@ -232,20 +233,20 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0}; + constexpr auto StrideDs = std::array{1, 1, 1}; ck::index_t KBatch = 1; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); - max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 0, 0, 0}; + max_token_id.mData = {valid_size}; int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; } - int token_per_tile = tokens * topk / valid_tile_num; + int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num; int tokenid = 0; for(int i = 0; i < sorted_size; i++) { @@ -260,17 +261,21 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); - Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); - Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); + Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); + Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K})); Tensor d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0})); - Tensor d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]})); + Tensor d1_e_n( + HostTensorDescriptor({experts, N * 2}, {StrideDs[1] * N * 2, StrideDs[1]})); + Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); Tensor e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); Tensor e_t_n_device_result( HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1})); std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl; std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl; std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl; std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; @@ -279,31 +284,35 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{}); break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); - DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); + DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); @@ -312,6 +321,7 @@ int main(int argc, char* argv[]) a0_device_buf.ToDevice(a0_t_k.mData.data()); d0_device_buf.ToDevice(d0_t_n.mData.data()); d1_device_buf.ToDevice(d1_e_n.mData.data()); + d2_device_buf.ToDevice(d2_e_n.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; @@ -424,7 +434,8 @@ int main(int argc, char* argv[]) a0_device_buf.GetDeviceBuffer(), b0_device_buf.GetDeviceBuffer(), std::array{d0_device_buf.GetDeviceBuffer(), - d1_device_buf.GetDeviceBuffer()}, + d1_device_buf.GetDeviceBuffer(), + d2_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), tokens, topk, @@ -440,20 +451,25 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || - !(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " "not support this GEMM problem"); } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + if(time_kernel) { float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * tokens * topk * N * K; + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + - sizeof(B0DataType) / 2 * K * N * experts + + sizeof(B0DataType) / 2 * K * N * 2 * experts + sizeof(EDataType) * valid_tile_num * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -475,10 +491,13 @@ int main(int argc, char* argv[]) using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm; + PassThrough, + Act_OP, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); @@ -487,8 +506,11 @@ int main(int argc, char* argv[]) max_token_id, MPerBlock, a0_t_k, + d0_t_n, b0_e_n_k, + d1_e_n, c_t_k_n, + d2_e_n, PassThrough{}, PassThrough{}, PassThrough{}); @@ -511,13 +533,14 @@ int main(int argc, char* argv[]) cde_element_op(e_t_n_host_result(t, topk_id, n), c_t_k_n(t, topk_id, n), d0_t_n(t, n), - d1_e_n(e, n)); + d1_e_n(e, n), + d2_e_n(e, n)); } } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); return ck::utils::check_err( - e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) ? 0 : 1; } diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp index 0d12441016..42d892fe26 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -25,7 +25,6 @@ template using S = ck::Sequence; using F16 = ck::half_t; -// using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using F32 = float; @@ -36,7 +35,7 @@ using A0DataType = F8; using B0DataType = F8; using EDataType = F16; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using D0DataType = F32; using D1DataType = F32; using D2DataType = F32; @@ -48,7 +47,6 @@ using ELayout = Row; using D0Layout = Row; using D1Layout = Col; using D2Layout = ELayout; -// using DsLayoutGate = ck::Tuple; using DsLayout = ck::Tuple; // d0: ascale, d1: bscale, d2:expert weight @@ -62,11 +60,19 @@ struct MulABScaleExpertWeight __host__ __device__ constexpr void operator()( EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { - // for real kernel use - // warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. - // tofix:felix (void)d0; - e = ck::type_convert(c * d1 * d2); + (void)d1; + (void)d2; + e = ck::type_convert(c); + } + template <> + __host__ __device__ constexpr void operator()( + EDataType& e, const EDataType& c, const float& d0, const float& d1, const float& d2) const + { + (void)d0; + (void)d1; + (void)d2; + e = ck::type_convert(c); } // for reference cpu template <> @@ -119,14 +125,12 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 2; -static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t NXDLPerWave = 4; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); -// static constexpr ck::index_t MXDLPerWave = MPerBlock / 32; //todo fix this constraint -// static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32; static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); @@ -135,6 +139,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off ///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -163,8 +168,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + 4, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; @@ -176,16 +181,13 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape ck::index_t N = 4096; ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 6; - ck::index_t valid_tile_num = 6; + ck::index_t sorted_tile_num = 16; + ck::index_t valid_tile_num = 13; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t tokens = 128; @@ -211,6 +213,18 @@ int main(int argc, char* argv[]) K = std::stoi(argv[5]); tokens = std::stoi(argv[6]); } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); @@ -228,15 +242,13 @@ int main(int argc, char* argv[]) ck::index_t KBatch = 1; - // const ck::index_t experts = 8; Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1})); - // max_token_id.mData[0] = valid_size; - // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; - max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; - int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + + max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; @@ -248,7 +260,7 @@ int main(int argc, char* argv[]) } int token_per_tile = tokens * topk / valid_tile_num; int tokenid = 0; - // sorted_token_ids.mData[0] = 0; + for(int i = 0; i < sorted_size; i++) { int tile_off = i % MPerBlock; @@ -262,8 +274,7 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - expert_ids.savetxt("expert_ids.txt", "int"); - sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + Tensor a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); @@ -314,12 +325,7 @@ int main(int argc, char* argv[]) DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - // a0_t_k_k.savetxt("a.txt"); - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); - // d0_t_n.savetxt("d0_t_n.txt", "int"); - // d1_e_n.savetxt("d1_e_n.txt", "int"); - // d2_e_n.savetxt("d2_e_n.txt", "int"); + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -397,7 +403,7 @@ int main(int argc, char* argv[]) e_device_buf.ToDevice(e_t_n_device_result.mData.data()); invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); - Tensor c_t_n({tokens, N}); + Tensor c_t_n({tokens, N}); using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2; + CDEElementOp, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, @@ -437,8 +444,7 @@ int main(int argc, char* argv[]) } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - // e_t_n_device_result.savetxt("out.txt"); - // e_t_n_host_result.savetxt("ref.txt"); + return ck::utils::check_err( e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) ? 0 diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp index 528503a2c4..3745e3d0af 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_pk_i4.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -62,11 +62,13 @@ struct MulABScaleExpertWeight EDataType& e, const float& c, const float& d0, const float& d1, const float& d2) const { (void)d0; + (void)d1; + (void)d2; #if CK_USE_PK4_LAYOUT_SHUFFLE - e = ck::type_convert(c * d1 * d2 * 16); + e = ck::type_convert(c * 16); #else - e = ck::type_convert(c * d1 * d2); + e = ck::type_convert(c); #endif } // for reference cpu @@ -125,10 +127,10 @@ using CDEElementOp = MulABScaleExpertWeight; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; static constexpr ck::index_t BLOCKSIZE = 256; -static constexpr ck::index_t MXDLPerWave = 4; -static constexpr ck::index_t NXDLPerWave = 1; +static constexpr ck::index_t MXDLPerWave = 8; +static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t NPerBlock = 128; -static constexpr ck::index_t MNPerXDL = 32; +static constexpr ck::index_t MNPerXDL = 16; static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; @@ -138,6 +140,7 @@ static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; static constexpr ck::index_t D2Vec = 1; +static constexpr bool MulRoutedWeight = true; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, @@ -148,8 +151,8 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic MXDLPerWave, NXDLPerWave, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, - 1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; + 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, false, ck::index_t, A0DataType>; // clang-format on int main(int argc, char* argv[]) @@ -158,9 +161,6 @@ int main(int argc, char* argv[]) int init_method = 1; bool time_kernel = true; - // tokens = 1 - // topk = 1 - // experts = 8 // per expert: // GEMM shape ck::index_t N = 4096; @@ -281,7 +281,7 @@ int main(int argc, char* argv[]) break; case 4: a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); - b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); @@ -298,7 +298,7 @@ int main(int argc, char* argv[]) DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.mDesc.GetElementSpaceSize()); - DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize() / 2); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); @@ -407,13 +407,18 @@ int main(int argc, char* argv[]) b_element_op, cde_element_op); - if(!device_op.IsSupportedArgument(argument) || - !(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + if(!device_op.IsSupportedArgument(argument)) { throw std::runtime_error( "wrong! device_gemm with the specified compilation parameters does " "not support this GEMM problem"); } + + if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950")) + { + std::cout << "This kernel support gfx942 and gfx950 only" << std::endl; + } + if(time_kernel) { // not result correct here because output buf not setzero @@ -450,7 +455,8 @@ int main(int argc, char* argv[]) AccDataType, PassThrough, PassThrough, - CDEElementOp>; + CDEElementOp, + MulRoutedWeight>; auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 9e95c3e007..1a1db51c37 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -1,10 +1,11 @@ add_custom_target(example_gemm_mx) -add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale) +add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) -add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale) +add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_bf8) + +add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) -add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale) diff --git a/example/67_gemm_microscaling/README.md b/example/67_gemm_microscaling/README.md index 713902588d..57b6490eda 100644 --- a/example/67_gemm_microscaling/README.md +++ b/example/67_gemm_microscaling/README.md @@ -10,16 +10,16 @@ Custom verification parameters: # arg4: verbosity (0=no info, 1=verbose info) # arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC # arg11: KBatch -./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1 +./bin/example_gemm_mx_fp8 1 1 0 1 ``` Custom tensor shapes: ```bash -./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1 +./bin/example_gemm_mx_fp8 1 2 1 0 128 128 256 -1 -1 -1 1 ``` Default invocation: ```bash -# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0 -./bin/example_gemm_mx_fp8_fp8_scale +# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0 +./bin/example_gemm_mx_fp8 ``` \ No newline at end of file diff --git a/example/67_gemm_microscaling/gemm_mx_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_bf8.cpp new file mode 100644 index 0000000000..8e341fb591 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_bf8.cpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::bf8_t; +using BDataType = ck::bf8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t KPerBlock = 128; + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 128, // BlockSize: Thread block size + 128, // MPerBlock + 16, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 1, // NXdlPerWave + S<8, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<8, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 2, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 9a05954c73..99ed2a23b9 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -95,7 +95,7 @@ bool parse_cmd_args(int argc, << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl - << "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl + << "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl << "arg11: KBatch" << std::endl; return false; } @@ -103,7 +103,8 @@ bool parse_cmd_args(int argc, return true; } -template + ck::index_t ScaleBlockSize> bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) { - static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; - static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; - - static constexpr ck::index_t ScaleBlockSize = MXVectorSize; - - static constexpr ck::index_t KPerBlock = 64; - using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - ADataType, // ADataType - XDataType, // AScaleDataType - BDataType, // BDataType - XDataType, // BScaleDataType - CDataType, // CDataType - AccDataType, // GemmAccDataType - CShuffleDataType, // CShuffleDataType - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - GemmSpec, // GemmSpec - MXVectorSize, // ScaleBlockSize: Scaling block size - 256, // BlockSize: Thread block size - 128, // MPerBlock - 128, // NPerBlock - KPerBlock, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 2, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_AK1 - false, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_BK1 - false, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock - BlkGemmPSched, // BlkGemmPipeSched - BlkGemmPVer, // BlkGemmPipelineVer - ADataType, // ComputeTypeA - BDataType // ComputeTypeB - >; auto M = problem_size.M; auto N = problem_size.N; @@ -230,8 +175,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, AScaleLayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BScaleLayout{})); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor a_m_k_scale(f_host_tensor_descriptor( M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A @@ -290,7 +235,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c break; case 2: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); @@ -428,8 +373,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c if(config.time_kernel) { - std::size_t flop = std::size_t(2) * M * N * K + - std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale + // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of + // partial sums(K/ScaleBlockSize)] + // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N + sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; @@ -445,7 +392,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c return res_verified; } -template , // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp new file mode 100644 index 0000000000..ce4ebc0a40 --- /dev/null +++ b/example/67_gemm_microscaling/gemm_mx_fp8_bf8.cpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_mx_common.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::bf8_t; + +using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = CDataType; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix +using CElementOp = PassThrough; // elementwise transformation for C matrix + +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size + +constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; +constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + ScaleBlockSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 256, // MPerBlock + 256, // NPerBlock + 128, // KPerBlock + 16, // AK1 + 8, // BK1 + 16, // MPerXDL + 16, // NPerXDL + 8, // MXdlPerWave + 8, // NXdlPerWave + S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp deleted file mode 100644 index 393f4a2ea7..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::e8m0_bexp_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 0 - : -1; -} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp deleted file mode 100644 index dd654a8f69..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::half_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 0 - : -1; -} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp deleted file mode 100644 index c42d9783be..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::f8_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 0 - : -1; -} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 64ff2a6813..996a543ecc 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -114,14 +114,14 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) #only continue if there are some source files left on the list if(FILE_NAME) if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 message("trimming targets for ${FILE_NAME}") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -212,7 +212,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) #only continue if there are some source files left on the list if(FILE_NAME) if(FILE_NAME MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) endif() diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 332707eafd..5b9d5742b4 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -114,12 +114,14 @@ LAYOUT_MAP = { PIPELINE_MAP = { "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", + "qs" : "ck_tile::BlockFmhaPipelineQSKSVS", } PIPELINE_ENUM_MAP = { "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS", } BOOL_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py new file mode 100644 index 0000000000..30b9299963 --- /dev/null +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -0,0 +1,595 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import copy +from dataclasses import dataclass +import fnmatch +import itertools +from pathlib import Path +from typing import List, Optional, Tuple + +from codegen.cmake_config import * +from codegen.cpp_symbol_map import * + + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +K0_MAX_SUBMAX_MAP = { + 32 : 32, + 64 : 64, + 96 : 128, + 128: 128, + 256: 256 +} + +FMHA_BATCH_PREFILL_PIPELINE_MAP = { + "qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync", +} + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +// auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape, + ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>, + ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>, + ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>, + {F_vlayout}>; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_logits}, + {F_bias}, + false, + {F_lse}, + {F_dropout}, + {F_squant}, + {F_occupancy}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::RandValOutputDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_variant_{F_idx}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_batch_prefill_(const ck_tile::stream_config& s, fmha_batch_prefill_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp" +FMHA_FWD_API=""" +float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +{F_inner_dispatch} + }} +""" + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_batch_prefill_(s, a); + }} +""" + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0max : int + vlayout : str + logits : str + mask : str + bias : str # + lse : str # + dropout : str + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {bk0submax} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] + if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_v % {bk0submax} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_logits : str # t/f + F_bias : str # true/false + F_lse : str # + F_dropout : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + else: n += '_npad' + + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + + if self.F_bias != 'no' : n += f'_{self.F_bias}' + else: n += '_nbias' + + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: n += '_nmask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + else: n += '_nmask' + + if self.F_lse == 't' : n += '_lse' + else: n += '_nlse' + + if self.F_dropout == 't' : n += '_dropout' + else: n += '_ndropout' + + if self.F_squant == 't' : n += '_squant' + else: n += '_nsquant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, + F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along k seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm0 : int # number of warps for gemm0 along q seqlen + F_rn0 : int # number of warps for gemm0 along k seqlen + F_rk0 : int # number of warps for gemm0 along head dim q (not used) + F_rm1 : int # number of warps for gemm1 along q seqlen + F_rn1 : int # number of warps for gemm1 along head dim v + F_rk1 : int # number of warps for gemm1 along k seqlen (not used) + F_wm0 : int # gemm0 warp size along m + F_wn0 : int # gemm0 warp size along n + F_wk0 : int # gemm0 warp size along k + F_wm1 : int # gemm1 warp size along m + F_wn1 : int # gemm1 warp size along n + F_wk1 : int # gemm1 warp size along k + F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\ + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\ + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = FWD_DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0max = self.F_tile.F_bk0max, + F_rm0 = self.F_tile.F_rm0, + F_rn0 = self.F_tile.F_rn0, + F_rk0 = self.F_tile.F_rk0, + F_rm1 = self.F_tile.F_rm1, + F_rn1 = self.F_tile.F_rn1, + F_rk1 = self.F_tile.F_rk1, + F_wm0 = self.F_tile.F_wm0, + F_wn0 = self.F_tile.F_wn0, + F_wk0 = self.F_tile.F_wk0, + F_wm1 = self.F_tile.F_wm1, + F_wn1 = self.F_tile.F_wn1, + F_wk1 = self.F_tile.F_wk1, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], + F_bias = BIAS_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0max=self.F_tile.F_bk0max, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + dropout=self.F_pipeline.F_dropout, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: + if dtype == 'fp16' or dtype == 'bf16': + return { + ### '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + ### '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + ### '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + ### '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + } + else: + return None + +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + if hdim == 256: + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + # the below two is used for hdim vectorize load + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + else: + if bias == "bias": + # TODO: rocm 6.2 compiler problem if using qr_async for bias case + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + if receipt == 1 and bias != "bias": + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask)) + elif dtype in ['fp8fp16', 'fp8bf16']: + # TODO + None + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for dtype in FWD_DTYPE_MAP.keys(): + d = get_fmha_fwd_tile_dict_from_dtype(dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + if hdim == 192 and tile.F_bn1 == 128: + # NOTE: this is used to speedup deepseek prefill case, we don't gen training + if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': + continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue + k = FmhaFwdKernel(F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != '': + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # 2 - Flash attention integration + if receipt in (2, 3): + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'alibi'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_bias in ['no', 'bias'] + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_fwd) integration + elif receipt == 100: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'batch' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # Aiter(mha_batch_prefill) integration + elif receipt == 200: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + # aiter::mha_batch_prefill C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == 'group' + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + write_single_fwd_kernel(kernel, output_dir) + write_fwd_api(api_pool, output_dir) + +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + with file_path.open('a') as f: + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 6326a97f8e..80b64f918a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -60,6 +60,7 @@ using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + false, {F_bias}, {F_dbias}, false, @@ -545,6 +546,12 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= dpad == dvpad if not cond: continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= dpad == dvpad + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) @@ -682,6 +689,11 @@ def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaB cond &= mode == "group" if not cond: continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen @@ -834,6 +846,11 @@ def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[Fmh cond &= mode == "group" if not cond: continue + # aiter::mha_bwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen @@ -850,9 +867,11 @@ def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autoge def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (3 - len(filter_list))) + # TODO + assert optdim_list == [-1] kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) for kernel in kernels: @@ -865,9 +884,11 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> Non write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) write_bwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (3 - len(filter_list))) + # TODO + assert optdim_list == [-1] with file_path.open('a') as f: kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index e5d11c6dc9..2f1287c87a 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -32,6 +32,7 @@ K0_MAX_SUBMAX_MAP = { FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py +#include "ck_tile/ops/fmha/block/variants.hpp" #include "fmha_fwd.hpp" """ @@ -51,12 +52,16 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, false, {F_lse}, {F_dropout}, {F_squant}, {F_occupancy}>; + +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; + using fmha_mask_{F_idx} = {F_mask}; using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< @@ -73,6 +78,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< typename FmhaFwdTypeConfig::ODataType, fmha_shape_{F_idx}, {F_mode}, + fmha_variant_{F_idx}, fmha_mask_{F_idx}, fmha_trait_{F_idx}>; @@ -88,7 +94,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdKernel; using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -123,9 +129,9 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < }} """ -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); }} """ @@ -144,6 +150,7 @@ class FmhaFwdApiTrait: bk1 : int # tile size along kv gemm unroll bk0max : int vlayout : str + logits : str mask : str bias : str # lse : str # @@ -157,7 +164,7 @@ class FmhaFwdApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' @property def scheck(self) -> str: @@ -165,7 +172,7 @@ class FmhaFwdApiTrait: if self.pipeline_tag == 'qr_async': if self.spad == 't' : return 'true' # always support else : return 'true' - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs']: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False @@ -176,7 +183,7 @@ class FmhaFwdApiTrait: if self.pipeline_tag == 'qr_async': if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' - elif self.pipeline_tag in ['qr', 'qr_fp8']: + elif self.pipeline_tag in ['qr', 'qs']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' else: assert False @@ -187,7 +194,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dpad == 't': return f'a.hdim_q % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_q % {bk0submax} == 0' @@ -199,7 +206,7 @@ class FmhaFwdApiTrait: vec = int((32 * 4) / DTYPE_BITS[self.dtype]) if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' else : assert False - elif self.pipeline_tag in ['qr']: + elif self.pipeline_tag in ['qr', 'qs']: bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max] if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.hdim_v % {bk0submax} == 0' @@ -214,6 +221,7 @@ class FmhaFwdPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # + F_logits : str # t/f F_bias : str # true/false F_lse : str # F_dropout : str # @@ -235,6 +243,9 @@ class FmhaFwdPipeline: if pn != '' : n += f'_{pn}' else: n += '_npad' + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' else: n += '_nbias' @@ -280,7 +291,7 @@ class FmhaFwdApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout] , F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, @@ -365,6 +376,7 @@ class FmhaFwdKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_dropout = BOOL_MAP[self.F_pipeline.F_dropout], @@ -399,6 +411,7 @@ class FmhaFwdKernel: bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, mask=self.F_pipeline.F_mask, + logits=self.F_pipeline.F_logits, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, dropout=self.F_pipeline.F_dropout, @@ -429,7 +442,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: else: return None -def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: @@ -440,33 +453,36 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): + for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): if hdim == 256: # if True: - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + # the below two is used for hdim vectorize load + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: if bias == "bias": # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', 'f', squant, mask)) + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask)) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -494,6 +510,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue k = FmhaFwdKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -504,6 +523,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # 2 - Flash attention integration if receipt in (2, 3): cond = dtype in ['fp16', 'bf16'] @@ -536,6 +558,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= pipeline.F_squant == 'f' if not cond: continue + # aiter::mha_fwd C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -547,15 +576,15 @@ def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : str, receipt, mask_impl) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) +def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None: with file_path.open('a') as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index f243020dc4..dc7ef712e2 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -343,13 +343,15 @@ def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> Non def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + assert optdim_list == [-1] api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None: +def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: + assert optdim_list == [-1] with file_path.open('a') as f: _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index c6d1a01792..3ae0e28be3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -45,6 +45,7 @@ FMHA_FWD_SPLITKV_PIPELINE_MAP = { FMHA_FWD_SPLITKV_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; +using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>; using fmha_mask_{F_idx} = {F_mask}; namespace {{ @@ -63,6 +64,7 @@ using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, + {F_logits}, {F_bias}, /*kHasBiasGrad=*/false, {F_lse}, @@ -85,16 +87,19 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< typename FmhaFwdTypeConfig::OaccDataType, fmha_shape, {F_mode}, + fmha_variant_{F_idx}, fmha_mask_{F_idx}, fmha_trait>; using fmha_pipeline = {F_pipeline}< fmha_pipeline_problem>; +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue using fmha_epilogue = ck_tile::Default2DEpilogue::OaccDataType, typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, - {F_spad}, {F_dvpad}>>; + false, false>>; using fmha_kernel = ck_tile::FmhaFwdSplitKVKernel; @@ -111,7 +116,7 @@ static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) }} using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; #include @@ -265,9 +270,9 @@ float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; @@ -308,6 +313,7 @@ class FmhaFwdSplitKVApiTrait: bk0max : int vlayout : str mask : str + logits : str bias : str # lse : str # squant : str # @@ -320,7 +326,7 @@ class FmhaFwdSplitKVApiTrait: @property def name(self) -> str: return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\ - f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ + f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\ f'{self.dvpad}-{self.pagedkv}' @property @@ -378,6 +384,7 @@ class FmhaFwdSplitKVPipeline: F_skpad : str # F_dpad : str # F_dvpad : str # + F_logits : str # t/f F_bias : str # true/false F_lse : str # F_squant : str # @@ -399,6 +406,9 @@ class FmhaFwdSplitKVPipeline: if pn != '' : n += f'_{pn}' else: n += '_npad' + if self.F_logits == 't' : n += '_logits' + else: n += '_nlogits' + if self.F_bias != 'no' : n += f'_{self.F_bias}' else: n += '_nbias' @@ -440,10 +450,10 @@ class FmhaFwdSplitKVCombinePipeline: n = f'{self.tag}' if pn != '' : n += f'_{pn}' else: n += '_npad' - + if self.F_lse == 't' : n += '_lse' else: n += '_nlse' - + if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' return n @@ -473,7 +483,7 @@ class FmhaFwdSplitKVApiPool: for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], - F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, @@ -539,6 +549,7 @@ class FmhaFwdSplitKVKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_logits = BOOL_MAP[self.F_pipeline.F_logits], F_bias = BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], @@ -572,6 +583,7 @@ class FmhaFwdSplitKVKernel: bk1=self.F_tile.F_bk1, bk0max=self.F_tile.F_bk0max, vlayout=self.F_pipeline.F_vlayout, + logits=self.F_pipeline.F_logits, mask=self.F_pipeline.F_mask, bias=self.F_pipeline.F_bias, lse=self.F_pipeline.F_lse, @@ -669,26 +681,32 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): # TODO: use async pipeline when compiler is more stable if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) else: - pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) if receipt == 1: - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim - pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask)) + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask)) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -712,6 +730,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue + # logits_soft_cap is only allowed if no bias + if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'): + continue k = Kernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -738,6 +759,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue + # aiter::mha_fwd_splikv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + cond &= pipeline.F_vlayout == 'row' + cond &= pipeline.F_squant == 'f' + if not cond: + continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -796,6 +824,11 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis cond &= mode == "group" if not cond: continue + # aiter::mha_fwd_splikv C++ api integration + elif receipt == 600: + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue gen.append(k) return gen @@ -807,9 +840,10 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME file_path.write_text(api_pool.api) -def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> None: +def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) + assert optdim_list == [-1] kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) for kernel in kernels: @@ -819,9 +853,10 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, mask_impl) -> Non write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, mask_impl) -> None: +def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) + assert optdim_list == [-1] with file_path.open('a') as f: kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8f6fb8df54..bb1f495c4e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -72,6 +73,7 @@ auto create_args(int argc, char* argv[]) "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" "note when squant=1, this value will be modified by range_q/k") + .insert("logits_soft_cap", "0", "attention logits soft capping value.") .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.") @@ -416,6 +418,8 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale_s == .0f) scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? + const float logits_soft_cap = arg_parser.get_float("logits_soft_cap"); + std::string squant_str = arg_parser.get_str("squant"); bool squant = [&]() { if(squant_str == "auto") @@ -850,6 +854,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else // fmha_fwd_traits or fmha_splitkv_traits { traits.is_group_mode = (mode == mode_enum::group); + traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; traits.has_lse = lse; @@ -1007,6 +1012,8 @@ bool run(const ck_tile::ArgParser& arg_parser) args.scale_p = scale_p; args.scale_o = scale_o; + args.logits_soft_cap = logits_soft_cap; + args.stride_bias = (bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias); args.stride_o = stride_o; @@ -1375,6 +1382,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale_s)); + if(0.f < logits_soft_cap) + { + ck_tile::reference_unary_elementwise( + s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) { + return ck_tile::type_convert( + logits_soft_cap * + std::tanhf(ck_tile::type_convert(logits / logits_soft_cap))); + }); + } + if(bias.type == bias_enum::elementwise_bias) { // elementwise bias diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 765c221a7b..1838ee5bd9 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -143,6 +143,8 @@ struct fmha_fwd_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -232,6 +234,8 @@ struct fmha_fwd_splitkv_args float scale_p; float scale_o; + float logits_soft_cap; + ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; @@ -308,6 +312,85 @@ struct fmha_fwd_appendkv_args ck_tile::index_t batch_stride_vnew; }; +struct fmha_batch_prefill_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; // bias or alibi_slope pointer + void* rand_val_ptr; + void* lse_ptr; + void* o_ptr; + + // the real seqlen_q & seqlen_k are decided by following: + // batch mode (kvcache): + // seqlen_q = kargs.seqlen_q + // seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] - + // 1) + + // kargs.kv_last_page_lens[b] + // group mode (kvcache): + // seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b] + // seqlen_k = kargs.page_block_size * (kargs.kv_indptr[b + 1] - kargs.kv_indptr[b] - + // 1) + + // kargs.kv_last_page_lens[b] + const void* seqstart_q_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + + // SGLang-style page table + int32_t num_total_pages; + void* kv_indptr; + void* kv_page_indices; +#if 0 // we assume page_block_size=1 for now + void* kv_last_page_lens; + ck_tile::index_t page_block_size; +#endif + + float scale_s; + float scale_p; + float scale_o; + + float logits_soft_cap; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 + ck_tile::index_t stride_randval; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_randval; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_randval; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; + + float p_drop; + bool s_randval; + + std::variant, std::pair> + drop_seed_offset; +}; + template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { @@ -333,6 +416,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -371,6 +455,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.scale_s, args.scale_p, args.scale_o, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -443,6 +528,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.is_gappy, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -485,6 +571,7 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) args.cache_batch_idx, args.scale_s, args.scale_p, + args.logits_soft_cap, args.stride_q, args.stride_k, args.stride_v, @@ -618,6 +705,117 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args) return ck_tile::make_tuple(kargs, grids); } +template +auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, +#if 0 // we assume page_block_size=1 for now + args.kv_last_page_lens, + args.page_block_size, +#endif + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_k, + args.batch_stride_v, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargsImpl(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.rand_val_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_total_pages, + args.kv_indptr, + args.kv_page_indices, +#if 0 // we assume page_block_size=1 for now + args.kv_last_page_lens, + args.page_block_size, +#endif + args.scale_s, + args.scale_p, + args.scale_o, + args.logits_soft_cap, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_randval, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_randval, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_randval, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type, + args.p_drop, + args.s_randval, + args.drop_seed_offset); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + // this is used to pattern-match internl kernel implementation, not to instantiate kernel template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -677,6 +877,7 @@ template ; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; @@ -776,6 +978,9 @@ struct fmha_fwd_appendkv_traits_ template float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args); +template +float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args); + // This is the public API, will be generated by script struct fmha_fwd_traits { @@ -784,6 +989,7 @@ struct fmha_fwd_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; @@ -800,6 +1006,7 @@ struct fmha_fwd_splitkv_traits std::string data_type; bool is_group_mode; bool is_v_rowmajor; + bool has_logits_soft_cap; mask_enum mask_type; bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_lse; @@ -821,3 +1028,8 @@ struct fmha_fwd_appendkv_traits float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&); + +using fmha_batch_prefill_traits = fmha_fwd_traits; +float fmha_batch_prefill(fmha_batch_prefill_traits, + fmha_batch_prefill_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 0d35db14d4..c611618824 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -21,8 +21,7 @@ class HandlerId(IntEnum): ops = [] for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) - if full_module_name not in sys.modules: - ops.append(importer.find_spec(module_name).loader.load_module(module_name)) + ops.append(importer.find_spec(module_name).loader.load_module(module_name)) unwanted_prefix = 'fmha_' handlers = dict( [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, @@ -30,7 +29,7 @@ handlers = dict( ) assert 0 < len(handlers) -def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: +def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: if output_dir is None: output_dir = Path(__file__).parent else: @@ -40,10 +39,10 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.WRITE_BLOBS] - handler(output_dir, kernel_filter, receipt, mask_impl) + handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl) # list all the files that will be generated -def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], receipt, mask_impl) -> None: +def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None: assert output_file is not None file_path = Path(output_file) @@ -52,7 +51,7 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : for api, kernel_filter in zip(api_list, filters_list): handler = handlers[api][HandlerId.LIST_BLOBS] - handler(file_path, kernel_filter, receipt, mask_impl) + handler(file_path, kernel_filter, receipt, optdim_list, mask_impl) if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -109,16 +108,28 @@ if __name__ == "__main__": " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ - " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration" + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ + " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" + ) + parser.add_argument( + "--optdim", + default='-1', + required=False, + help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \ + "eg. --optdim=32,64,128,256" ) args = parser.parse_args() api_list = args.direction.split(',') filter_list = args.filter.split(',') filter_list.extend([''] * (len(api_list) - len(filter_list))) + optdim_list = [int(hdim) for hdim in args.optdim.split(',')] + + if len(api_list) > 1: + assert optdim_list == [-1] if args.list_blobs is not None: - list_blobs(args.list_blobs, api_list, filter_list, int(args.receipt), mask_impl=args.mask) + list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) else: - write_blobs(args.output_dir, api_list, filter_list, int(args.receipt), mask_impl=args.mask) + write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 30cfee22f6..411db2e317 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,5 +1,9 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) -target_compile_options(tile_example_gemm_universal PRIVATE - -mllvm -enable-noalias-to-md-conversion=0 -) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp old mode 100755 new mode 100644 index 69051423fb..1edb3da947 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -53,50 +53,67 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - constexpr dim3 blocks = Kernel::BlockSize(); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - if(!Kernel::IsSupportedArgument(kargs)) + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + return Run(ck_tile::integral_constant{}); } - - if(s.log_level_ > 0) + else { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + return Run(ck_tile::integral_constant{}); } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; } #include "run_gemm_example.inc" diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 3254a407fd..25fab6bde0 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -55,17 +55,17 @@ struct GemmConfig #endif #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) // Compute friendly for Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64; + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 2; static constexpr ck_tile::index_t N_Warp = 2; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; static constexpr bool DoubleSmemBuffer = false; #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) @@ -93,7 +93,8 @@ struct GemmConfig static constexpr bool PermuteA = false; static constexpr bool PermuteB = false; - static constexpr bool TransposeC = false; + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 6cb40e45d1..79ed9ce76b 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -55,7 +55,8 @@ void permute_tensor_b(Tensor& tensor) ALayout, BLayout, CLayout, - GemmConfig::TransposeC>; + GemmConfig::TransposeC, + GemmConfig::UseStructuredSparsity>; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; - std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K - << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name - << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name - << " B Type = " << DataTypeTraits::name - << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " - << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K + << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name + << " B_Type=" << DataTypeTraits::name + << " C_Type=" << DataTypeTraits::name + << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") + << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; return ave_time; } @@ -240,8 +243,8 @@ int run_gemm_example_with_layouts(int argc, if(init_method == 0) { - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); } else if(init_method == 1) { @@ -250,8 +253,8 @@ int run_gemm_example_with_layouts(int argc, } else if(init_method == 2) { - ck_tile::FillConstant{static_cast(1)}(a_m_k); - ck_tile::FillConstant{static_cast(1)}(b_k_n); + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); } else { @@ -259,6 +262,11 @@ int run_gemm_example_with_layouts(int argc, b_k_n.SetZero(); } + if(GemmConfig::UseStructuredSparsity) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); @@ -394,7 +402,6 @@ int run_gemm_example_with_layouts(int argc, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); - std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index eef8d3b60e..b60a3b274b 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -12,6 +12,19 @@ #include "ck_tile/host.hpp" #include "gemm_utils.hpp" +template +void try_run(ck_tile::TailNumber tn) +{ + if constexpr(Pipeline::PrefetchStages > static_cast(TN)) + { + if(tn == TN) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +} + template ; + GemmConfig::TransposeC, + GemmConfig::UseStructuredSparsity>; using GemmPipelineProblem = ck_tile::GemmPipelineProblem; @@ -60,10 +74,13 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -115,23 +133,40 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& return ave_time; }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -142,76 +177,39 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& throw std::runtime_error(err.str()); } #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) - // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } + auto check_tail = [&](auto... TNs) { + (try_run(tail_num), ...); + }; + + check_tail(ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}); + #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } #endif } @@ -219,18 +217,18 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index e59fcaedad..ce689a370c 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -153,9 +153,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) local_expert_masking_dev.ToDevice(local_expert_masking_host.data()); // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr - ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts); + ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); - if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 109ec1b157..305cf118d2 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -7,6 +7,14 @@ #define MOE_SORTING_USE_EX_KERNEL 1 #endif +#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0 +#endif + +#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK +#define MOE_SORTING_SUPPORT_LARGE_TOPK 0 +#endif + #if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ @@ -153,7 +161,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi } } #else - if(moe_sorting_get_workspace_size(a.tokens, a.num_experts) != 0) + if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0) { return moe_sorting_mp(t, a, s); } @@ -171,57 +179,107 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi return -1; } -#define MOE_SORTING_MP_0(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_1(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#if MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ }() -#define MOE_SORTING_MP_2(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ +#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#endif + +#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_size = kernel::GetSmemSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ }() -#define MOE_SORTING_MP_3(unroll_num_, expert_masking_) \ - [&]() { \ - constexpr ck_tile::index_t unroll_num = unroll_num_; \ - constexpr bool expert_masking = expert_masking_; \ - using ms_problem = \ - ck_tile::MoeSortingProblemMp; \ - using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ - auto kargs = kernel::MakeKargs(a); \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(a); \ - return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ - }() +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \ + return ave_time; \ + } float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) { @@ -230,29 +288,74 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co using ms_index_t = ck_tile::index_t; using ms_weight_type = float; - if(t.local_expert_masking) + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(1, true), - MOE_SORTING_MP_1(1, true), - MOE_SORTING_MP_2(1, true), - MOE_SORTING_MP_3(1, true)); - return ave_time; +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif } else { - float ave_time = ck_tile::launch_kernel(s, - MOE_SORTING_MP_0(1, false), - MOE_SORTING_MP_1(1, false), - MOE_SORTING_MP_2(1, false), - MOE_SORTING_MP_3(1, false)); - return ave_time; + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } } } return -1; } -int moe_sorting_get_workspace_size(int tokens, int num_experts) +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk) { - return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts); + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); } diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp index b47ae9013b..0fe8d81e70 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -22,6 +22,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs // if return non zero, means need workspace, you need to allocate a GPU buffer // and set to moe_sorting_args.p_ws // NOTE: workspace size are required to clear zero before use the API -int moe_sorting_get_workspace_size(int tokens, int num_experts); +int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk); float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh index cf2c2e164b..fbfb10822c 100644 --- a/example/ck_tile/13_moe_sorting/script/smoke_test.sh +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -26,3 +26,9 @@ $EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11 $EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19 $EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33 $EXE -t=11 -e=256 -k=5 -local_eid=99,110,129 +$EXE -t=128 -e=128 -k=6 -moe_buf_size=163840 +$EXE -t=8192 -e=32 -k=5 -moe_buf_size=163840 +$EXE -t=8192 -e=32 -k=8 -moe_buf_size=163840 +$EXE -t=8192 -e=256 -k=5 -moe_buf_size=163840 +$EXE -t=8192 -e=256 -k=8 -moe_buf_size=163840 +$EXE -t=163840 -e=256 -k=8 -moe_buf_size=163840 \ No newline at end of file diff --git a/example/ck_tile/15_fused_moe/fused_moe.hpp b/example/ck_tile/15_fused_moe/fused_moe.hpp index b354d1d347..46425384cc 100644 --- a/example/ck_tile/15_fused_moe/fused_moe.hpp +++ b/example/ck_tile/15_fused_moe/fused_moe.hpp @@ -56,4 +56,6 @@ struct fused_moe_traits bool local_expert_masking; // if mask experts as local expert }; +// if return zero, no ws needed +int fused_moe_get_workspace_size(int tokens, int num_experts, int topk); float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/15_fused_moe/fused_moesorting.hpp b/example/ck_tile/15_fused_moe/fused_moesorting.hpp index a3ff8c5bf7..11e1c6e531 100644 --- a/example/ck_tile/15_fused_moe/fused_moesorting.hpp +++ b/example/ck_tile/15_fused_moe/fused_moesorting.hpp @@ -18,4 +18,5 @@ struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs { }; +int fused_moe_get_workspace_size(int tokens, int num_experts, int topk); float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index f887d57aa9..b3515b1bec 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -2,6 +2,12 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "fused_moe.hpp" +#include "ck_tile/ops/fused_moe.hpp" + +int fused_moe_get_workspace_size(int tokens, int num_experts, int topk) +{ + return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk); +} float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s) { diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 7aedaa9317..0d83c48d02 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -7,6 +7,14 @@ #define MOE_SORTING_USE_EX_KERNEL 1 #endif +#ifndef MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_SUPPORT_LARGE_EXPERT 0 +#endif + +#ifndef MOE_SORTING_SUPPORT_LARGE_TOPK +#define MOE_SORTING_SUPPORT_LARGE_TOPK 0 +#endif + #if !MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ @@ -107,6 +115,10 @@ } #endif +float fused_moesorting_mp(fused_moesorting_trait t, + fused_moesorting_args a, + ck_tile::stream_config s); + float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) { if(t.weight_type == "fp32" && t.index_type == "int32") @@ -153,18 +165,198 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til } } #else - using index_t = ck_tile::index_t; - using ms_weight_type = float; - auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts); - auto sub_token_ = r_ - 2; - r_ = (r_ - 2) / 8; - bool is_sub_token_onshot = a.tokens <= sub_token_; + if(fused_moe_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0) + { + return fused_moesorting_mp(t, a, s); + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + auto sub_token_ = ck_tile::moe_sorting_get_sub_token(a.tokens, a.num_experts); + auto row_ = sub_token_ / 8; + bool is_sub_token_onshot = a.tokens <= sub_token_; bool is_local_expert_masking = t.local_expert_masking; - (void)c_; - MOE_SORTING_DISPATCH_EMASK_(r_); + MOE_SORTING_DISPATCH_EMASK_(row_); // MOE_SORTING_DISPATCH_ETILE(0, 0); #endif } return -1; } + +#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#if MOE_SORTING_SUPPORT_LARGE_EXPERT +#define MOE_SORTING_MP_2(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() + +#define MOE_SORTING_MP_3(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \ + }() +#endif + +#define MOE_SORTING_MP_23(mesh_type_, unroll_num_, expert_masking_) \ + [&]() { \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + constexpr bool expert_masking = expert_masking_; \ + using ms_problem = ck_tile::MoeSortingProblemMp; \ + using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_size = kernel::GetSmemSize(a); \ + return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \ + }() + +#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \ + if(t.local_expert_masking) \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true)); \ + return ave_time; \ + } \ + else \ + { \ + float ave_time = \ + ck_tile::launch_kernel(s, \ + MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false), \ + MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false), \ + MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false)); \ + return ave_time; \ + } + +float fused_moesorting_mp(fused_moesorting_trait t, + fused_moesorting_args a, + ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { + using ms_index_t = ck_tile::index_t; + using ms_weight_type = float; + + if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) > + ck_tile::get_smem_capacity()) + { +#if MOE_SORTING_SUPPORT_LARGE_EXPERT + if(t.local_expert_masking) + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, true), + MOE_SORTING_MP_1(ms_index_t, 1, true), + MOE_SORTING_MP_2(ms_index_t, 1, true), + MOE_SORTING_MP_3(ms_index_t, 1, true)); + return ave_time; + } + else + { + float ave_time = ck_tile::launch_kernel(s, + MOE_SORTING_MP_0(ms_index_t, 1, false), + MOE_SORTING_MP_1(ms_index_t, 1, false), + MOE_SORTING_MP_2(ms_index_t, 1, false), + MOE_SORTING_MP_3(ms_index_t, 1, false)); + return ave_time; + } +#else + printf("do not support large expert %d\n", a.num_experts); + return -1; +#endif + } + else + { + ck_tile::index_t mesh_byte_size = + ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk); + if(mesh_byte_size == 1) + { + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16) + } + } + else if(mesh_byte_size == 2) + { +#if MOE_SORTING_SUPPORT_LARGE_TOPK + if(a.tokens * a.topk % 4 == 0) + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8) + } + else + { + MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8) + } +#else + printf("do not support large topk %d\n", a.topk); + return -1; +#endif + } + else + { + MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1) + } + } + } + return -1; +} diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index cb93ce8907..da843891ce 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -372,7 +372,8 @@ bool run(const ck_tile::ArgParser& arg_parser) num_sorted_tiles_host.get_element_space_size_in_bytes()); // if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr - ck_tile::index_t workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts); + ck_tile::index_t workspace_size = + ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk); ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0); if(workspace_size != 0) moe_sorting_ws.SetZero(); // note, clear here!!!! diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index a0cd18ec74..0219c67305 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -106,61 +106,81 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } - - if(s.log_level_ > 0) + else { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } - - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; }; if(has_hot_loop) @@ -168,18 +188,18 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -193,20 +213,21 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -214,7 +235,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -222,7 +244,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -230,7 +253,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -238,7 +262,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -246,20 +271,22 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } #endif } @@ -267,18 +294,18 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } std::ostringstream err; err << "Incorrect tail_num for pipeline without hotloop, expected Full, Odd or Even, but " diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 2a9903362d..9b134ff779 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -114,66 +114,86 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = GEMM_PIPELINE; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); + using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); - const dim3 grids = Kernel::GridSize(gemm_descs); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + constexpr dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(p_workspace_), + gemm_descs.size())); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); } - - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(p_workspace_), - gemm_descs.size())); - return ave_time; }; if(has_hot_loop) @@ -181,18 +201,18 @@ float grouped_gemm(const std::vector& gemm_descs, #if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3) if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Odd) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Even) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -206,20 +226,21 @@ float grouped_gemm(const std::vector& gemm_descs, // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } if constexpr(BaseGemmPipeline::PrefetchStages > 2) { if(tail_num == ck_tile::TailNumber::Two) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -227,7 +248,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -235,7 +257,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Four) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -243,7 +266,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Five) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -251,7 +275,8 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Six) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } @@ -259,20 +284,22 @@ float grouped_gemm(const std::vector& gemm_descs, { if(tail_num == ck_tile::TailNumber::Seven) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } } #elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4) if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } #endif } diff --git a/example/ck_tile/18_flatmm/CMakeLists.txt b/example/ck_tile/18_flatmm/CMakeLists.txt new file mode 100644 index 0000000000..f4d823e91a --- /dev/null +++ b/example/ck_tile/18_flatmm/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp) + +set(EXAMPLE_FLATMM_COMPILE_OPTIONS) +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter) +list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) +#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -DENABLE_FP8=1 -Wno-unused-local-typedef) +target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm/README.md b/example/ck_tile/18_flatmm/README.md new file mode 100644 index 0000000000..beaac785fc --- /dev/null +++ b/example/ck_tile/18_flatmm/README.md @@ -0,0 +1,35 @@ +# FLATMM Matrix Multiplication + +This folder contains example for FLATMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile FLATMM, but creates the placeholders for the future support on different FLATMM pipeline and different FLATMM modules. In the near future, we will gradually migrate all the FLATMM features from old CK to CK Tile. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# The basic pipeline method on the flatmm calculation +make tile_example_flatmm_basic -j +``` +This will result in an executable `build/bin/tile_example_flatmm_basic` + +## example +``` +args: + -b batch size (default:1) + -m m dimension (default:1024) + -n n dimension (default:2048) + -k k dimension (default:64) + -a_layout Tensor A data layout (default: R) + -b_layout Tensor B data layout (default: R) + -c_layout Tensor C data layout (default: R) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2) + -e Absolute error tolerance (default:1e-5) + -prec data type. fp16/bf16/fp8/bf8 (default:fp16) + -warmup number of iterations before benchmark the kernel (default:10) + -repeat number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) +``` diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp new file mode 100644 index 0000000000..5f2c2a5aab --- /dev/null +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "flatmm_basic.hpp" + +template +float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s) +{ + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 2; + + // This part comes from the Codegen +#if defined(USING_MFMA_16x16x32) || defined(ENABLE_FP16) + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 128; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 4; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 16 : 32; + constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 16 : 32; + constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 64 : 16; + +#elif defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 128; + + constexpr ck_tile::index_t M_Warp = 1; + constexpr ck_tile::index_t N_Warp = 8; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type::value ? 32 : 32; + constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type::value ? 32 : 32; + constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type::value ? 32 : 16; +#endif + using CodegenFlatmmShape = + ck_tile::TileFlatmmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmTraits; + using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy; + using CodegenFlatmmPipeline = + ck_tile::FlatmmPipelineAGmemBGmemCRegV1; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::FlatmmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" + << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } +} + +#include "run_flatmm_example.inc" + +int run_flatmm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "C") + { + if(data_type == "fp16") + { + run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf16") + { + run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "fp8") + { + run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else if(data_type == "bf8") + { + run_flatmm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data_type!"); + } + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } + return -1; +} + +int main(int argc, char* argv[]) { return !run_flatmm_example(argc, argv); } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.hpp b/example/ck_tile/18_flatmm/flatmm_basic.hpp new file mode 100644 index 0000000000..bbce978724 --- /dev/null +++ b/example/ck_tile/18_flatmm/flatmm_basic.hpp @@ -0,0 +1,136 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/flatmm.hpp" +#include "ck_tile/ops/gemm.hpp" + +#define CK_TILE_PIPELINE_COMPUTE 1 +#define CK_TILE_PIPELINE_MEMORY 2 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE +#endif + +#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave +#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE) +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave +#else +#error "unsupported CK_TILE_PIPELINE_DEFAULT value" +#endif + +template +struct GemmBasicTypeConfig; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template <> +struct GemmBasicTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template +struct is_8bit_type + : std::bool_constant || std::is_same_v> +{ +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "256", "m dimension") + .insert("n", "256", "n dimension") + .insert("k", "128", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Row by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/18_flatmm/run_flatmm_example.inc b/example/ck_tile/18_flatmm/run_flatmm_example.inc new file mode 100644 index 0000000000..c191fff7d0 --- /dev/null +++ b/example/ck_tile/18_flatmm/run_flatmm_example.inc @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include + +template +constexpr const char* DataTypeToString() +{ + if constexpr(std::is_same_v) + { + return "fp16"; + } + else if constexpr(std::is_same_v) + { + return "fp8"; + } + else if constexpr(std::is_same_v) + { + return "bf8"; + } + else + { + return "unknown"; + } +} + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +// mfma_type, 0:32x32, 1:16x16 +template +auto shuffle_b(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + + if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 16, 2, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 32, 4, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8" || mfma_dtype == "bf8") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + return t; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf, + ck_tile::DeviceMem& b_shuffle_dev_buf, + ck_tile::DeviceMem& c_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + ck_tile::FlatmmHostArgs args; + args.a_ptr = a_dev_buf.GetDeviceBuffer(); + args.b_shuffle_ptr = b_shuffle_dev_buf.GetDeviceBuffer(); + args.c_ptr = c_dev_buf.GetDeviceBuffer(); + + args.k_batch = kbatch; + args.M = M; + args.N = N; + args.K = K; + args.stride_A = stride_A; + args.stride_B = stride_B; + args.stride_C = stride_C; + + float ave_time = + flatmm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString() + << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A + << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time + << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +template +int run_flatmm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_host( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_origin_host( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_rslt_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + // TODO: add different init types + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_origin_host); + + ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes()); + + a_dev_buf.ToDevice(a_host.data()); + c_rslt_host.SetZero(); + + // do pre-shuffle + std::string mfma = arg_parser.get_str("prec"); +#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) + ck_tile::index_t mfma_type = 1; +#else + ck_tile::index_t mfma_type = 0; +#endif + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_origin_host, mfma, mfma_type); + ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes()); + b_shuffle_dev_buf.ToDevice(b_shuffle_host.data()); + + invoke_flatmm( + a_dev_buf, + b_shuffle_dev_buf, + c_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); + + c_dev_buf.FromDevice(c_rslt_host.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_ref_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_ref_host.SetZero(); + + ck_tile::reference_gemm( + a_host, b_origin_host, c_ref_host); + const float max_accumulated_value = + *std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, + c_ref_host, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes()); + b_origin_dev_buf.ToDevice(b_origin_host.data()); + + ck_tile::HostTensor c_gpu_ref_host( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes()); + c_gpu_ref_host.SetZero(); + c_gpu_ref_dev_buf.SetZero(); + + ADataType* d_A; + BDataType* d_B; + CDataType* d_C; + + ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType))); + ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType))); + ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType))); + + ck_tile::hip_check_error(hipMemcpy( + d_A, a_dev_buf.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice)); + ck_tile::hip_check_error(hipMemcpy(d_B, + b_origin_dev_buf.GetDeviceBuffer(), + N * K * sizeof(BDataType), + hipMemcpyHostToDevice)); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(), + d_C, + M * N * sizeof(CDataType), + hipMemcpyDeviceToHost)); + + ck_tile::hip_check_error(hipFree(d_A)); + ck_tile::hip_check_error(hipFree(d_B)); + ck_tile::hip_check_error(hipFree(d_C)); + + c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data()); + const float max_accumulated_value = + *std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_rslt_host, + c_gpu_ref_host, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} diff --git a/example/ck_tile/18_flatmm/script/smoke_test_basic.sh b/example/ck_tile/18_flatmm/script/smoke_test_basic.sh new file mode 100755 index 0000000000..a3fc61cc31 --- /dev/null +++ b/example/ck_tile/18_flatmm/script/smoke_test_basic.sh @@ -0,0 +1,34 @@ +#!/bin/bash +EXE="$(find . -name tile_example_flatmm_basic -type f | head -n 1)" +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=2 -warmup=0 -repeat=1' + +run_tests() { + for m in 128 1024; do + for n in 128 2048; do + for k in 128 4096; do + + $EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS + if [ $? -eq 0 ]; then + echo "Success: Test with m=$m, n=$n, k=$k executed successfully." + else + echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly." + # Optionally, exit or break if you need to halt further execution + # exit 1 + fi + + done + done + done +} + +set -x + +run_tests "bf16" +run_tests "fp16" + +set +x diff --git a/example/ck_tile/35_batched_transpose/README.md b/example/ck_tile/35_batched_transpose/README.md index d0583e7529..38bb2b32e4 100644 --- a/example/ck_tile/35_batched_transpose/README.md +++ b/example/ck_tile/35_batched_transpose/README.md @@ -24,4 +24,6 @@ args: -layout_out output tensor data layout - NHWC by default -seed seed to be used, -1 means random every time (default:-1) -k_name t to 1 will print kernel name (default:0) + -warmup warmup iterations to run this kernel (default:50) + -repeat number of iterations to run this kernel (default:100) ``` \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp index 77d768fe3f..1eb0445c84 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp @@ -1,7 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "batched_transpose_example.hpp" -#include template + ck_tile::index_t thread_y, + bool kPadM, + bool kPadN> float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) { - uint32_t dim_block_h = (a.height + block_y - 1) / block_y; - uint32_t dim_block_w = (a.width + block_x - 1) / block_x; - uint32_t dim_stride = a.height * a.width; + uint32_t dim_stride = a.height * a.width; a.dim_stride = dim_stride; - a.dim_block_h = dim_block_h; - a.dim_block_w = dim_block_w; + a.dim_block_h = block_y; + a.dim_block_w = block_x; using block_tile = ck_tile::sequence; using warp_tile = ck_tile::sequence; using thread_tile = ck_tile::sequence; using ts_problem = - ck_tile::BatchedTransposeProblem; + ck_tile::BatchedTransposeProblem; using ts_pipeline = ck_tile::BatchedTransposePipeline; using kernel = ck_tile::BatchedTransposeKernel; @@ -35,25 +34,40 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con const dim3 grids = kernel::GridSize(a); constexpr dim3 blocks = kernel::BlockSize(); + printf("Grid: %u %u %u\n", grids.x, grids.y, grids.z); + printf("Block: %u %u %u\n", blocks.x, blocks.y, blocks.z); + printf("kargs: kargs.batch %d kargs.height %d kargs.width %d kargs.dim_strid %d\n", + kargs.batch, + kargs.height, + kargs.width, + kargs.dim_stride); + + printf("Launching Kernel...\n"); + float ave_time = ck_tile::launch_kernel( s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); + printf("Kernel finished...\n"); + return ave_time; } // Param Comb: type_size, block_x & y, warp_x & y, thread_x & y -#define FOREACH_TRANSPOSE_PARAM(F) \ - F(fp16, ck_tile::fp16_t, 16, 16, 8, 8, 1, 1) \ - F(bf16, ck_tile::bf16_t, 16, 16, 8, 8, 1, 1) \ - F(fp32, ck_tile::fp32_t, 16, 16, 8, 8, 1, 1) \ - F(int8, ck_tile::int8_t, 16, 16, 8, 8, 1, 1) +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true) \ + F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false) \ + F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true) \ + F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \ + F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true) \ + F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false) // Macro that defines one static function per line -#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY) \ - static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY( \ - batched_transpose_kargs& a, ck_tile::stream_config& s) \ - { \ - return batched_transpose_dispatch(a, s); \ +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN) \ + static float \ + transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch(a, s); \ } FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) @@ -62,21 +76,38 @@ float batched_transpose(batched_transpose_trait t, batched_transpose_kargs a, ck_tile::stream_config s) { - if(t.type == "fp16") + if(t.type == "fp8") { - return transpose_fn_fp16_16_16_8_8_1_1(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s); + } + else + { + return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s); + } + } + else if(t.type == "fp16") + { + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s); + } + else + { + return transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s); + } } else if(t.type == "bf16") { - return transpose_fn_bf16_16_16_8_8_1_1(a, s); - } - else if(t.type == "fp32") - { - return transpose_fn_fp32_16_16_8_8_1_1(a, s); - } - else if(t.type == "int8") - { - return transpose_fn_int8_16_16_8_8_1_1(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s); + } + else + { + return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s); + } } return -1; } diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp index 48fc2859bf..33b6f0eacf 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp @@ -21,13 +21,13 @@ void dump_host_tensor_4d(const ck_tile::HostTensor& x) std::cout << "["; for(size_t i = 0; i < len[0]; i++) { - std::cout << i << ": ["; + std::cout << "Batch " << i << ":" << std::endl; for(size_t j = 0; j < len[1]; j++) { - std::cout << j << ": ["; + std::cout << " Channel " << j << ":" << std::endl; for(size_t k = 0; k < len[2]; k++) { - std::cout << k << ": ["; + std::cout << " Row " << k << ": "; for(size_t v = 0; v < len[3]; v++) { if constexpr(std::is_same_v) @@ -41,15 +41,15 @@ void dump_host_tensor_4d(const ck_tile::HostTensor& x) } else { - std::cout << x(std::vector{i, j, k, v}) << " "; + std::cout << static_cast(x(std::vector{i, j, k, v})) + << " "; } } - std::cout << "]" << std::endl; + std::cout << std::endl; } - std::cout << "]" << std::endl; } - std::cout << std::endl; } + std::cout << "]" << std::endl; std::cout << "--------------------" << std::endl; } #endif @@ -93,12 +93,14 @@ auto create_args(int argc, char* argv[]) ck_tile::ArgParser arg_parser; arg_parser.insert("v", "1", "whether do CPU validation or not") .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)") - .insert("N", "2", "input batch size. ") - .insert("C", "16", "input channel size.") - .insert("H", "1", "input height size.") - .insert("W", "16", "input width size. ") + .insert("N", "1", "input batch size. ") + .insert("C", "64", "input channel size.") + .insert("H", "18", "input height size.") + .insert("W", "64", "input width size. ") .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "t to 1 will print kernel name"); @@ -115,6 +117,8 @@ bool run_batched_transpose(ck_tile::ArgParser args) int C = args.get_int("C"); int H = args.get_int("H"); int W = args.get_int("W"); + int n_warmup = args.get_int("warmup"); + int n_repeat = args.get_int("repeat"); std::string layout_in = args.get_str("layout_in"); std::string layout_out = args.get_str("layout_out"); int seed = args.get_int("seed"); @@ -177,7 +181,7 @@ bool run_batched_transpose(ck_tile::ArgParser args) return a_; }(); - ck_tile::stream_config sc{nullptr, true}; + ck_tile::stream_config sc{nullptr, true, n_warmup, n_repeat}; auto ms = batched_transpose(trait, karg, sc); @@ -202,7 +206,8 @@ bool run_batched_transpose(ck_tile::ArgParser args) layout_in.c_str(), ms); if(ms < 0) - printf("not supported\n"); + printf("------------------------------------not " + "supported-------------------------------------\n"); fflush(stdout); if(ms < 0) @@ -227,7 +232,9 @@ bool run_batched_transpose(ck_tile::ArgParser args) rtn &= ck_tile::check_err( y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); } - printf("valid:%s\n", rtn ? "y" : "n"); + printf("-----------------------------------------------------------------------valid:%s--------" + "--------------------------------------------------------------------\n", + rtn ? "y" : "n"); fflush(stdout); return rtn; } @@ -240,9 +247,9 @@ int main(int argc, char** argv) std::string prec = args.get_str("pr"); bool r = true; - if(prec.compare("fp32") == 0) + if(prec.compare("fp8") == 0) { - r &= run_batched_transpose(args); + r &= run_batched_transpose(args); } else if(prec.compare("fp16") == 0) { @@ -252,10 +259,6 @@ int main(int argc, char** argv) { r &= run_batched_transpose(args); } - else if(prec.compare("int8") == 0) - { - r &= run_batched_transpose(args); - } return r ? 0 : -1; } diff --git a/example/ck_tile/35_batched_transpose/script/perf_test.sh b/example/ck_tile/35_batched_transpose/script/perf_test.sh new file mode 100755 index 0000000000..7ecfefc580 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/script/perf_test.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +EXE=./build/bin/tile_example_batched_transpose + +for pr in "fp8" "fp16" "bf16"; do +$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=4096 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' + +done \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/script/run_full_test.sh b/example/ck_tile/35_batched_transpose/script/run_full_test.sh new file mode 100755 index 0000000000..4d0c988912 --- /dev/null +++ b/example/ck_tile/35_batched_transpose/script/run_full_test.sh @@ -0,0 +1,38 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the tile_example_batched_transpose executables in ../build/bin/ +# +# run the script as "./run_full_test.sh +# input arguments: +# environment tag : a string describing the specifics of your test environment +# branch name : name of the branch in git repo (git status | grep -e 'On branch') +# host name : $hostname +# gpu architecture: e.g., gfx90a, or gfx942, etc. + +#get the command line arguments: +export env_type=$1 +echo 'Environment type: ' $env_type +export branch=$2 +echo 'Branch name: ' $branch +export host_name=$3 +echo 'Host name: ' $host_name +export GPU_arch=$4 +echo 'GPU_arch: ' $GPU_arch + +function print_log_header(){ + rm -f $1; + echo 'On branch ' $3 &> $1; + echo 'Node name: ' $4 >> $1; + #get GPU_arch and number of compute units from rocminfo + echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1; + rocminfo | grep "Compute Unit:" >> $1; + hipcc --version | grep -e 'HIP version' >> $1; + echo 'Environment type: ' $2 >> $1; + /opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1; +} + +#run verification tests +example/ck_tile/35_batched_transpose/script/smoke_test.sh + +#run performance benchmarks + diff --git a/example/ck_tile/35_batched_transpose/script/smoke_test.sh b/example/ck_tile/35_batched_transpose/script/smoke_test.sh index fdfef2cea8..fdc01a2eb4 100755 --- a/example/ck_tile/35_batched_transpose/script/smoke_test.sh +++ b/example/ck_tile/35_batched_transpose/script/smoke_test.sh @@ -2,10 +2,26 @@ EXE=./build/bin/tile_example_batched_transpose -for pr in "fp32" "fp16" "int8" ; do +for pr in "fp8" "fp16" "bf16"; do $EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' $EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' $EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' $EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' $EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' -done +$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=16 -C=64 -H=32 -W=128 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=16 -C=64 -H=128 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW' + +done \ No newline at end of file diff --git a/example/ck_tile/36_copy/CMakeLists.txt b/example/ck_tile/36_copy/CMakeLists.txt new file mode 100644 index 0000000000..d1b9ba923c --- /dev/null +++ b/example/ck_tile/36_copy/CMakeLists.txt @@ -0,0 +1,4 @@ +add_executable(test_copy_kernel EXCLUDE_FROM_ALL test_copy.cpp) +target_compile_options(test_copy_kernel PRIVATE + -mllvm -enable-noalias-to-md-conversion=0 +) \ No newline at end of file diff --git a/example/ck_tile/36_copy/README.md b/example/ck_tile/36_copy/README.md new file mode 100644 index 0000000000..7856f0b4bd --- /dev/null +++ b/example/ck_tile/36_copy/README.md @@ -0,0 +1,31 @@ +# Copy Kernel +This folder contains basic setup code designed to provide a platform for novice +CK_Tile kernel developers to test basic functionality with minimal additional +code compared to the functional code. Sample functional code for a simple +tile distribution for DRAM window and LDS window are provided and data is moved +from DRAM to registers, registers to LDS, LDS to registers and finally data +is moved to output DRAM window for a simple copy operation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture +# (for example gfx90a or gfx942) or leave it blank +sh ../script/cmake-ck-dev.sh ../ +# Make the copy kernel executable +make test_copy -j +``` +This will result in an executable `build/bin/test_copy_kernel` + +## example +``` +args: + -m input matrix rows. (default 64) + -n input matrix cols. (default 8) + -id warp to use for computation. (default 0) + -v validation flag to check device results. (default 1) + -prec datatype precision to use. (default fp16) + -warmup no. of warmup iterations. (default 50) + -repeat no. of iterations for kernel execution time. (default 100) +``` \ No newline at end of file diff --git a/example/ck_tile/36_copy/test_copy.cpp b/example/ck_tile/36_copy/test_copy.cpp new file mode 100644 index 0000000000..81ea5255fc --- /dev/null +++ b/example/ck_tile/36_copy/test_copy.cpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include +#include "test_copy.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "64", "m dimension") + .insert("n", "8", "n dimension") + .insert("id", "0", "warp to use") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "50", "cold iter") + .insert("repeat", "100", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + using XDataType = DataType; + using YDataType = DataType; + + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t warp_id = arg_parser.get_int("id"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + ck_tile::HostTensor x_host({m, n}); + ck_tile::HostTensor y_host_ref({m, n}); + ck_tile::HostTensor y_host_dev({m, n}); + + // ck_tile::FillConstant{1.f}(x_host); + ck_tile::half_t value = 1; + for(int i = 0; i < m; i++) + { + value = 1; + for(int j = 0; j < n; j++) + { + x_host(i, j) = value++; + } + } + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + + using BlockWaves = ck_tile::sequence<2, 1>; + using BlockTile = ck_tile::sequence<64, 8>; + using WaveTile = ck_tile::sequence<64, 8>; + using Vector = ck_tile::sequence<1, 4>; + + ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); + std::cout << "grid size " << kGridSize << std::endl; + + using Shape = ck_tile::TileCopyShape; + using Problem = ck_tile::TileCopyProblem; + using Kernel = ck_tile::TileCopy; + + constexpr ck_tile::index_t kBlockSize = 128; + constexpr ck_tile::index_t kBlockPerCu = 1; + std::cout << "block size " << kBlockSize << std::endl; + std::cout << "warp SIze " << ck_tile::get_warp_size() << std::endl; + std::cout << "warps per block _M " << Shape::WarpPerBlock_M << " " << Shape::WarpPerBlock_N + << std::endl; + std::cout << "Block waves: " << BlockWaves::at(ck_tile::number<0>{}) << " " + << BlockWaves::at(ck_tile::number<1>{}) << std::endl; + std::cout << " Wave Groups: " << Shape::WaveGroups << std::endl; + + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), + m, + n, + warp_id)); + + std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + // reference + y_buf.FromDevice(y_host_dev.mData.data()); + pass = ck_tile::check_err(y_host_dev, x_host); + + std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + return run(arg_parser) ? 0 : -2; +} diff --git a/example/ck_tile/36_copy/test_copy.hpp b/example/ck_tile/36_copy/test_copy.hpp new file mode 100644 index 0000000000..8fed22a3d0 --- /dev/null +++ b/example/ck_tile/36_copy/test_copy.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +namespace ck_tile { + +template + typename BlockTile, // block size, seq + typename WaveTile, // warp size, seq + typename Vector> // contiguous elements(vector size) along seq +struct TileCopyShape +{ + // We split Workgroup waves into two specialized groups. + // One for reading data from global -> LDS, the other is doing reduction + static constexpr index_t WaveGroups = 2; + static constexpr index_t MWarps = BlockWaves::at(number<0>{}); + static constexpr index_t NWarps = BlockWaves::at(number<0>{}); + + static constexpr index_t Block_M = BlockTile::at(number<0>{}); + static constexpr index_t Block_N = BlockTile::at(number<1>{}); + + static constexpr index_t Warp_M = WaveTile::at(number<0>{}); + static constexpr index_t Warp_N = WaveTile::at(number<1>{}); + + static constexpr index_t Vector_M = Vector::at(number<0>{}); + static constexpr index_t Vector_N = Vector::at(number<1>{}); + + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t WarpPerBlock_M = + integer_divide_ceil(BlockWaves::at(number<0>{}), WaveGroups); + static constexpr index_t WarpPerBlock_N = + integer_divide_ceil(BlockWaves::at(number<1>{}), WaveGroups); + + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + static constexpr index_t WaveNum = reduce_on_sequence(BlockWaves{}, multiplies{}, number<1>{}); + + static constexpr index_t BlockSize = get_warp_size() * WaveNum; + static constexpr index_t WaveGroupSize = WaveNum / WaveGroups; + static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, "Inconsisten wave group size!"); +}; + +template +struct TileCopyProblem +{ + using XDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +template +struct TileCopy +{ + using Problem = ck_tile::remove_cvref_t; + using XDataType = typename Problem::XDataType; + + template + CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution() + { + using S = typename Problem::BlockShape; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t X0 = S::ThreadPerWarp_N; // threads needed along N dimension, fastest + // changing with given vector size. + constexpr index_t X1 = + S::Vector_N; // no. of elements along N dimensions to be read by each thread. + + constexpr index_t Y0 = + S::WaveNum / S::WaveGroups; // no. of active warps working in this thread block. + constexpr index_t Y1 = warp_size / X0; // no. of threads in a warp needed along M dimension. + constexpr index_t Y2 = + S::Warp_M / + (Y1 * + Y0); // no. of iterations each warp needs to perform to cover the entire tile window. + + constexpr auto outer_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}; + return make_static_tile_distribution(outer_encoding); + } + + CK_TILE_DEVICE void + operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const + { + using S = typename Problem::BlockShape; + + // LDS Data. + __shared__ XDataType x_lds[number{} * number{}]; + XDataType* __restrict__ p_x_lds = static_cast(x_lds); + + const auto x_lds_desc = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, 1), + number{}, + number<1>{}); + + auto x_lds_block_desc = transform_tensor_descriptor( + x_lds_desc, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform( + make_tuple(number{} / S::Vector_N, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + auto x_lds_view = make_tensor_view(p_x_lds, x_lds_block_desc); + + auto x_block_lds_window = + make_tile_window(x_lds_view, + make_tuple(number{}, number{}), + {0, 0}, + MakeDRAMDistribution()); + auto x_block_lds_window_no_dist = make_tile_window( + x_lds_view, make_tuple(number{}, number{}), {0, 0}); + + // Input tensor + const auto iM = get_block_id() * S::Block_M; + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + auto x_block_window = + make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {iM, 0}, + MakeDRAMDistribution()); + + // Output tensor + const auto y_m = make_naive_tensor_view( + p_y, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + auto y_block_window = + make_tile_window(y_m, make_tuple(number{}, number{}), {iM, 0}); + + // Programming logic + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + auto my_id = get_warp_id(); + + auto DramTileDist = x_block_window.get_tile_distribution(); + using dram_reg_tile = decltype(make_static_distributed_tensor(DramTileDist)); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + dram_reg_tile dram_tile; + + if(my_id == warp_id) + { + // load from DRAM to registers + load_tile(dram_tile, x_block_window); + + // store in lds + store_tile(x_block_lds_window_no_dist, dram_tile); + + // read from lds to registers + load_tile(dram_tile, x_block_lds_window); + + // store from registers to DRAM + store_tile(y_block_window, dram_tile); + } + __syncthreads(); + move_tile_window(x_block_window, {0, S::Block_N}); + move_tile_window(y_block_window, {0, S::Block_N}); + } + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 7f4ba2ed35..d479cd35f6 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -17,4 +17,6 @@ add_subdirectory(14_moe_smoothquant) add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) +add_subdirectory(18_flatmm) add_subdirectory(35_batched_transpose) +add_subdirectory(36_copy) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 5fa73d2fda..e38f166c1a 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -6,15 +6,10 @@ #include "ck/config.h" #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -#include "ck/utility/env.hpp" #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" #endif - -// environment variable to enable logging: -// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED -CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #endif // to do: add various levels of logging with CK_LOG_LEVEL @@ -70,7 +65,8 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define __gfx103__ #endif #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ - defined(__gfx1103__) || defined(__gfx11_generic__) + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx11_generic__) #define __gfx11__ #endif #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) @@ -129,7 +125,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx9__) // for GPU code +#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 @@ -248,6 +244,15 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) // workaround: compiler issue on gfx950 #define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1 +// workaround: compiler issue on gfx950 +#define CK_TEMP_DISABLE_FP4_TESTS 1 + +// workaround: compiler issue on gfx950 +#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1 + +// workaround: compiler issue on gfx950 +#define CK_WORKAROUND_BF16_TO_FP8_CONVERSION 1 + // denorm test fix, necessary for gfx90a #ifndef CK_GFX90A_DENORM_WORKAROUND #define CK_GFX90A_DENORM_WORKAROUND 0 diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 994e60025d..306a6c2ff1 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2023 Advanced Micro Devices, Inc. + * Copyright (c) 2025 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -115,6 +115,10 @@ #cmakedefine CK_USE_WMMA @CK_USE_WMMA@ #endif +#ifndef CK_USE_WMMA_FP8 +#cmakedefine CK_USE_WMMA_FP8 @CK_USE_WMMA_FP8@ +#endif + #ifndef CK_USE_GFX94 #cmakedefine CK_USE_GFX94 @CK_USE_GFX94@ #endif diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 3323ab6c7b..5439bbe1f0 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -86,7 +86,9 @@ inline bool is_gfx103_supported() inline bool is_gfx11_supported() { return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; + ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" || + ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" || + ck::get_device_name() == "gfx1152"; } inline bool is_gfx12_supported() diff --git a/include/ck/host_utility/flush_cache.hpp b/include/ck/host_utility/flush_cache.hpp index 918fb28ea9..08b3aba2b3 100644 --- a/include/ck/host_utility/flush_cache.hpp +++ b/include/ck/host_utility/flush_cache.hpp @@ -8,6 +8,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/stream_config.hpp" #include "ck/host_utility/hip_check_error.hpp" #include "ck/utility/flush_icache.hpp" diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index 5c1c1c4e60..11a1c9bbc0 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -6,6 +6,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/stream_config.hpp" #include "ck/host_utility/hip_check_error.hpp" diff --git a/include/ck/library/utility/fill.hpp b/include/ck/library/utility/fill.hpp index 3336041354..35625d142e 100644 --- a/include/ck/library/utility/fill.hpp +++ b/include/ck/library/utility/fill.hpp @@ -94,7 +94,7 @@ struct FillMonotonicSeq template void operator()(ForwardIter first, ForwardIter last) const { - std::generate(first, last, [=, n = init_value_]() mutable { + std::generate(first, last, [=, *this, n = init_value_]() mutable { auto tmp = n; n += step_; return tmp; @@ -150,7 +150,7 @@ struct TransformIntoStructuralSparsity template void operator()(ForwardIter first, ForwardIter last) const { - std::for_each(first, last, [=, idx = 0](T& elem) mutable { + std::for_each(first, last, [=, *this, idx = 0](T& elem) mutable { auto tmp_idx = idx; idx += 1; return elem *= valid_sequences[tmp_idx % (sizeof(valid_sequences) / sizeof(T))]; diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index edf58b20b4..71417ce7bf 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -51,7 +51,8 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) { os << ck::type_convert(v); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v || + std::is_same_v) { const auto packed_floats = ck::type_convert(v); const ck::vector_type vector_of_floats{packed_floats}; @@ -252,7 +253,7 @@ struct ParallelTensorFunctor std::size_t iw_begin = it * work_per_thread; std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d); - auto f = [=] { + auto f = [=, *this] { for(std::size_t iw = iw_begin; iw < iw_end; ++iw) { call_f_unpack_args(mF, GetNdIndices(iw)); @@ -359,7 +360,8 @@ struct Tensor std::size_t GetElementSpaceSize() const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return (mDesc.GetElementSpaceSize() + 1) / 2; } @@ -514,7 +516,8 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mDesc.GetOffsetFromMultiIndex(is...) / 2; } @@ -527,7 +530,8 @@ struct Tensor template T& operator()(Is... is) { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; } @@ -540,7 +544,8 @@ struct Tensor template const T& operator()(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; } @@ -552,7 +557,8 @@ struct Tensor T& operator()(std::vector idx) { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; } @@ -564,7 +570,8 @@ struct Tensor const T& operator()(std::vector idx) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; } diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index 274051da83..785f74a3c0 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -81,6 +81,18 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f4x2_pk_t operator()(Is...) + { + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{value, value})}; + } +}; + template <> struct GeneratorTensor_1 { @@ -209,6 +221,21 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f4x2_pk_t operator()(Is...) + { + float tmp0 = (std::rand() % (max_value - min_value)) + min_value; + float tmp1 = (std::rand() % (max_value - min_value)) + min_value; + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{tmp0, tmp1})}; + } +}; + template struct GeneratorTensor_3 { @@ -296,6 +323,25 @@ struct GeneratorTensor_3 } }; +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f4x2_pk_t operator()(Is...) + { + float tmp0 = float(std::rand()) / float(RAND_MAX); + float tmp1 = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp0 = min_value + tmp0 * (max_value - min_value); + float fp32_tmp1 = min_value + tmp1 * (max_value - min_value); + + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{fp32_tmp0, fp32_tmp1})}; + } +}; + template struct GeneratorTensor_4 { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp new file mode 100644 index 0000000000..ebe075b55d --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -0,0 +1,363 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_mx_pipeline_base +{ + using ComputeTypeA = ADataType; + using ComputeTypeB = BDataType; + using AccType = float; // for now only support V_MFMA_SCALE_F32 + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs. + static constexpr index_t WaveSize = 64; + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t AMmaKStride = KPack; + static constexpr index_t BMmaKStride = KPack; + + //> store rows/cols into thread registers in chunks of 16 + //> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47] + static constexpr index_t KThreadChunk = 16; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + static constexpr index_t KPerInnerLoop = KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = + ck::BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + /** + * @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding XDL and + * repeat dimensions. + */ + __host__ __device__ + BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp new file mode 100644 index 0000000000..2fdabc6bc7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp" + +namespace ck { + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmWmmaops_pipeline_v3{}; + } + else + { + static_assert(false, "BlockGemmPipeline configuration is not available"); + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp new file mode 100644 index 0000000000..31c4729760 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +template +struct BlockwiseGemmWmmaops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 32; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerWmma); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerWmma); + + static constexpr index_t A_LDS_Read_Width = ALDSReadWidth; + static constexpr index_t B_LDS_Read_Width = BLDSReadWidth; + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t C_WMMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerWmma * NPerWmma * KPerWmma); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerWmma: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + KPerWmma); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C WMMA inst: %d\n" + "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " + "%d, %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_WMMA_Inst_Num, + A_LDS_Read_Width, + B_LDS_Read_Width, + ALDSWriteWidth, + BLDSWriteWidth, + ABufferLoadWidth, + BBufferLoadWidth); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp new file mode 100644 index 0000000000..a63d32802e --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmWmmaops_pipeline_base +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I5 = Number<5>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = 32; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + +#if defined(__gfx12__) + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; +#else + static constexpr index_t A_KRow = 1; + static constexpr index_t B_KRow = 1; +#endif + + static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5); + + static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!"); + static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!"); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t KRepeat = KPerBlock / KPack; + + static constexpr auto WmmaK = Number{}; + + using HotLoopInstList = + ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst; + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + +#if defined(__gfx12__) + const auto wmma_krow = wmma_gemm.GetSubGroupId(); +#else + const auto wmma_krow = 0; +#endif + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + +#if defined(__gfx12__) + const auto wmma_krow = wmma_gemm.GetSubGroupId(); +#else + const auto wmma_krow = 0; +#endif + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWmma))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWmma))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + + /** + * @brief Constructor for BlockwiseGemmWmmaops_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding WMMA and + * repeat dimensions. + */ + __host__ __device__ + BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AWmmaTileDesc::IsKnownAtCompileTime() && + BWmmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 && + NPerBlock % (NPerWmma * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1; + + protected: + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWmma] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + using BThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp new file mode 100644 index 0000000000..2fb95f0f8d --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -0,0 +1,466 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmWmmaops_pipeline_v3 +{ +}; + +template +struct BlockwiseGemmWmmaops_pipeline_v3 + : BlockwiseGemmWmmaops_pipeline_base +{ + using Base = BlockwiseGemmWmmaops_pipeline_base; + using Base::I0; + + using Base::A_K1; + using Base::A_KRow; + using Base::B_K1; + using Base::B_KRow; + using Base::KRepeat; + using Base::WmmaK; + + using Base::wmma_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + using Base::GetCThreadBuffer; + using Base:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + + using Base::a_block_desc_k0_m0_m1_m2_k1; + using Base::b_block_desc_k0_n0_n1_n2_k1; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // TODO: Calculation of the number of instructions may require changes for WMMA + /* + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num; + + constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_wmma_rate = + (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_wmma_rate = + (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_wmma = + (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate; + constexpr auto num_dsread_b_wmma = + (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma); + constexpr auto num_wmma_per_issue = + num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA + }); + + // stage 2 + static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >= + ds_read_a_wmma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_wmma - 1) * + ds_read_a_wmma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + + static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >= + ds_read_b_wmma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_wmma - 1) * + ds_read_b_wmma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + */ + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp new file mode 100644 index 0000000000..29750b8baa --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -0,0 +1,621 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_thread_dequant_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + StaticallyIndexedArray{}> b_thread_dequant_bufs; + StaticallyIndexedArray{}> + b_thread_dequant_bufs_up; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I0)); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up + [mfma_reg_buf][Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(local_read_buf)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(local_read_buf)); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); + + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I1)); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic< + BDataType, + ComputeDataType, + decltype(b_block_desc_n0_n1_k0_k1), + decltype(b_block_desc_n0_n1_k0_k1), + tensor_operation::element_wise::PassThrough, + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + KPack>; + + const PassThrough b_element_op{}; + BThreadDequantCopy b_thread_dequant_copy_{b_element_op}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp new file mode 100644 index 0000000000..73749c6309 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -0,0 +1,573 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = + HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2; + constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 64) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // if constexpr(i.value < num_buffer_load_inst_a) { + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // } + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp index a94ef03008..074b5873ee 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp @@ -3,8 +3,10 @@ #pragma once +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp" @@ -33,57 +35,112 @@ template + index_t KPack, + bool GUFusion = false> constexpr auto BlockGemmBPreshufflePipeline_Selector() { if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { if constexpr(std::is_same::value) { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + } } else { - return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1< - BlkGemmPipeSche, - BlockSize, - ADataType, - BDataType, - ComputeDataType, - AccDataType, - ATileDesc, - BTileDesc, - AMmaTileDesc, - BMmaTileDesc, - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerXDL, - NPerXDL, - MRepeat, - NRepeat, - KPack>{}; + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } } } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index d751543175..1d27a74bd7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -280,12 +281,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -346,14 +349,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -409,14 +416,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -495,7 +506,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp index 4c019a41a4..7bbaaca5b6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp @@ -122,6 +122,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -281,12 +282,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(I0)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -318,14 +321,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_buf)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -389,14 +396,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_reg)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -445,12 +456,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_reg), - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, I0), - a_thread_bufs(local_read_reg)); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); }); @@ -539,7 +553,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 6d115e7620..6f3a7e6357 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -5,6 +5,16 @@ #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" +#define DS_READ_A_PREFETCH_STAGES 2 + +template +constexpr auto compute_stage_loads(T total_loads, T stages) +{ + return std::make_pair((total_loads + stages - 1) / stages, // ceil + total_loads / stages // floor + ); +} + namespace ck { // Compute optimized pipeline @@ -123,6 +133,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack; + constexpr index_t K2 = KPack / KGroup; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat; + constexpr index_t K0 = KRepeat * KGroup; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -184,298 +191,132 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3 - __device__ static constexpr auto HotLoopScheduler(Stage stage) + __device__ static constexpr auto HotLoopScheduler() { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; + constexpr auto num_total_stages = MRepeat; - static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / MRepeat; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / MRepeat; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto total_buffer_loads = num_buffer_load_inst_a + num_buffer_load_inst_b; + constexpr auto stages_available = MRepeat - DS_READ_A_PREFETCH_STAGES; + + constexpr auto stage_loads = compute_stage_loads(total_buffer_loads, stages_available); + + constexpr auto buffer_load_perstage_more = stage_loads.first; + constexpr auto buffer_load_perstage_less = stage_loads.second; + + constexpr auto buffer_load_stages_more = total_buffer_loads % stages_available; + + constexpr auto buffer_b_heavy_loads = buffer_load_perstage_more * buffer_load_stages_more; + constexpr auto buffer_b_remaining = + num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more; + + constexpr auto buffer_load_b_stages = + buffer_b_heavy_loads > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + buffer_b_remaining / buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - DS_READ_A_PREFETCH_STAGES - buffer_load_b_stages; + + static_assert(buffer_load_a_stages > 0, + "The buffer load a stages should always have a value over 0."); + + constexpr auto buffer_load_issue_point_interval_more = + math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_more); + constexpr auto buffer_load_issue_point_interval_less = + buffer_load_perstage_less == 0 + ? INT32_MAX + : math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_less); + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); } }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 2) - { - constexpr auto staged_num_mfma_per_buffer_load_a = - math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a; - - // A global - static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - } - } - }); - - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - template - __device__ static constexpr auto EpilogueScheduler_1(Stage stage) - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; - constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - if constexpr(stage.value == 0) - { - constexpr auto staged_num_buffer_load_b_per_ds_read_a = - num_buffer_load_inst_b / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_buffer_load_b = - staged_num_mfma / num_buffer_load_inst_b; - // B global - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { - ignore = ibuf_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read - }); - - __builtin_amdgcn_sched_barrier(0); - } - else if constexpr(stage.value == 1) - { -#if 0 - constexpr auto staged_num_ds_write_a_per_ds_read_a = - num_ds_write_inst_a / staged_num_ds_read_inst_a; - constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a; - // A local write - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - - static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) { - ignore = idswrite_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - }); - - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); -#elif 1 - constexpr auto staged_num_mfma_per_ds_write_a = - math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); - - constexpr auto stage_more_mfma = - staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; - - // A local write - static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { - if constexpr(i_inst.value < stage_more_mfma) - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - else - { - if(i_inst.value < staged_num_ds_read_inst_a) - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - } - else - { - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write - } - } - }); -#endif - __builtin_amdgcn_sched_barrier(0); - } - else - { - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier( - 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read - }); - - __builtin_amdgcn_sched_barrier(0); - } - } - - __device__ static constexpr auto EpilogueScheduler_2() - { - constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; - - constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; - - constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; - constexpr auto staged_num_mfma = num_mfma / MRepeat; - - constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; - - // A local Read - static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { - ignore = i_inst; - __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA - __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read }); - __builtin_amdgcn_sched_barrier(0); + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0); + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } + }); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, DS_READ_A_PREFETCH_STAGES, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } + }); + }); } template {}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(I0, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(I0, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, DS_READ_A_PREFETCH_STAGES, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + // K = k0 × KGroup × k1 = k0 × kg0 × A_K1 + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); }); // Initialize C @@ -558,26 +404,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(local_read_buf)); - b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); - } - else if constexpr(m0.value == 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); - } - else if constexpr(m0.value == 2) - { - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -613,49 +451,88 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(local_read_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<0>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(mfma_reg_buf), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % - 2>{}, - I0, - I0, - k0, - I0, - I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); } - - HotLoopScheduler(m0); }); + HotLoopScheduler(); }; LoopFunc(I0, I1); @@ -667,20 +544,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto m0) { - if constexpr(m0.value == 0) - { - b_blockwise_copy.Run(b_grid_desc, - b_grid_buf, - b_block_desc_n0_n1_k0_k1, - b_block_origin_idx, - b_thread_bufs(I1)); - } - else if constexpr(m0.value == MRepeat - 1) - { - a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); - } + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { vector_type a_thread_vec; @@ -707,36 +578,68 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number<0>{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } else { static_for<0, KRepeat, 1>{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); } - - EpilogueScheduler_1(m0); }); + HotLoopScheduler(); + static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -764,25 +667,29 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I1), - a_thread_desc_, - make_tuple( - Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); - // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle - // latency - // __builtin_amdgcn_sched_barrier(0); + + HotLoopScheduler(); } else if constexpr(TailNum == TailNumber::Odd) { @@ -813,18 +720,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(Number{}, I0, I0, k0, I0, I0), - a_block_buf.At(I0), - a_thread_desc_, - make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), - a_thread_buf); + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); }); - - EpilogueScheduler_2(); } }); } @@ -841,7 +751,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3, + Sequence<1, 1, 1, 1, 1, KPack / KGroup>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index d7ba2559ea..6c1c5b1c4d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -46,7 +46,8 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = + BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {}); static constexpr auto xdlops_gemm = XdlopsGemm{}; @@ -57,6 +58,11 @@ struct BlockwiseGemmXdlops_pipeline_base static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; static constexpr index_t KRepeat = KPerThread / KPack; static constexpr index_t KPerInnerLoop = KPack; + static constexpr index_t KGroup = + ((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) || + (MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64)) + ? 2 + : 1; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); @@ -333,7 +339,7 @@ struct BlockwiseGemmXdlops_pipeline_base return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( c_grid_desc_g_m0_n0_m1_n1_m2_n2); } - + __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; } static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp index 24f6afc381..c1433659d6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp @@ -7,6 +7,35 @@ namespace ck { +/** + * @brief Define matrix data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_data_type() +{ + return is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v; +} + +/** + * @brief Define scale data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_scale_type() +{ + return is_same_v; +} + +/** + * @brief Combination of data types that have hardware support for MX GEMMs + */ +template +static constexpr bool scale_mfma_hw_support() +{ + return is_scale_mfma_data_type() && is_scale_mfma_data_type() && + is_scale_mfma_scale_type() && is_scale_mfma_scale_type(); +} + template constexpr auto BlockGemmMXPipeline_Selector() { + + // Hardware MX GEMM pipeline if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { return BlockwiseGemmXdlops_pipeline_v1_mx - : BlockwiseGemmXdlops_pipeline_base + : BlockwiseGemmXdlops_mx_pipeline_base { - using Base = BlockwiseGemmXdlops_pipeline_base; + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; using Base::I0; using Base::I1; using Base::KRepeat; @@ -134,7 +125,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; __host__ static constexpr bool BlockHasHotloop(index_t num_loop) { @@ -172,45 +173,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); auto a_scale_thread_buf = make_static_buffer( @@ -276,49 +238,31 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { - auto a_scale_thread_buf_group = + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = make_static_buffer( - a_scale_thread_desc_group.GetElementSpaceSize()); - + a_scale_thread_desc_copy.GetElementSpaceSize()); a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, - a_scale_thread_desc_group, + a_scale_thread_desc_copy, make_tuple(I0, I0), - a_scale_thread_buf_group); + a_scale_thread_buf_copy); - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto i) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, i)); - a_scale_thread_buf(Number{}) = - a_scale_thread_buf_group[Number{}]; - }); - // go to the next group + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, - make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); - }); // g - - // restore row id and advance to the next scale - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * - xdlops_gemm.mfma_instr.num_groups_per_blk, - 1)); - }); // k0 - - // restore column id and advance to the next set of rows + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); // m0 + }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, @@ -326,15 +270,32 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto n0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, I0), - b_scale_thread_buf); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, 0)); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_buf(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); }); + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); @@ -345,8 +306,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx(); - // main body if constexpr(HasMainLoop) { @@ -363,141 +322,166 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; - constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); }); static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, Number{}), + b_thread_buf); + }); }); }); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - c_thread_buf_per_scale.Clear(); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); // MFMA accumulation - // m = 1:MPerXDL - // n = 1:NPerXDL - // k = 1:KPack - // c(m,n) += a(m,k)*b(k,n) xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - - // one scale per k0 - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); - - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}( - [&](auto g) { - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}( - [&](auto r) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset( - make_tuple(m0, k0, g, r)); - - constexpr auto reg_offset = - g * xdlops_gemm.mfma_instr.group_size + r; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, reg_offset)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert( - b_scale_thread_buf[Number{}]) * - type_convert( - a_scale_thread_buf[Number{}]); - }); - }); + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); + // Prefetch a_scales static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { - auto a_scale_thread_buf_group = + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = make_static_buffer( - a_scale_thread_desc_group.GetElementSpaceSize()); - + a_scale_thread_desc_copy.GetElementSpaceSize()); a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, - a_scale_thread_desc_group, + a_scale_thread_desc_copy, make_tuple(I0, I0), - a_scale_thread_buf_group); + a_scale_thread_buf_copy); - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); - a_scale_thread_buf(Number{}) = - a_scale_thread_buf_group[Number{}]; - }); - // go to the next group + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, - make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); - }); // g - - // restore row id and advance to the next scale - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * - xdlops_gemm.mfma_instr.num_groups_per_blk, - 1)); - }); // k0 - - // restore column id and advance to the next set of rows + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); // m0 + }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + // Prefetch b_scales static_for<0, NRepeat, 1>{}([&](auto n0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, I0), - b_scale_thread_buf); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, 0)); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_buf(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); }); + + // restore col id and advance to the next set of scales // NWaves * NPerXDL * NRepeat == NPerBlock b_scale_thread_copy.MoveSrcSliceWindow( b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); @@ -507,7 +491,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto k) { - constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; - constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); + // read block data in chunks to assemble correct thread + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); }); static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); + // read block data in chunks to assemble correct thread + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, Number{}), + b_thread_buf); + }); }); }); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - c_thread_buf_per_scale.Clear(); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - - // one scale per k0 constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - constexpr auto reg_offset = - g * xdlops_gemm.mfma_instr.group_size + r; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, reg_offset)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert( - b_scale_thread_buf[Number{}]) * - type_convert( - a_scale_thread_buf[Number{}]); - }); + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); } } - // TODO: make this field protected when a_scale_thread_copy_ is moved here + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})); + make_tuple(Number{}, Number{}, Number{})); // Is used to copy data from a_scale_grid to a_scale_thread - static constexpr auto a_scale_thread_desc_group = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number<1>{})); + static constexpr auto a_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); - // TODO: make this field protected when b_scale_thread_copy_ is moved here - static constexpr auto b_scale_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // Is used to copy data from b_scale_grid to b_scale_thread_buf + static constexpr auto b_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); protected: using Base::a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp index 859649185a..92aef65388 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp @@ -41,6 +41,7 @@ template struct ThreadGroupTensorSliceTransfer_v4r1_gather @@ -58,7 +59,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather const DstDesc& dst_desc, const Index& dst_block_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray& gather_offsets) + const StaticallyIndexedArray& gather_offsets) : threadwise_transfer_(src_desc, make_zero_multi_index(), src_element_op, @@ -190,6 +191,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1_gather DstScalarStrideInVector, ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun, + IndexType, GatherDim, NumThreadScratch>; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index cf758e4d5f..bee0b01a74 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -42,6 +42,7 @@ template __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id); + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); } } @@ -149,7 +149,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or @@ -169,10 +169,9 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) + StaticallyIndexedArray& scatter_offsets) { - RunRead(src_descs, src_bufs, scatter_weights); + RunRead(src_descs, src_bufs); RunWrite(dst_descs, dst_bufs, scatter_offsets); } @@ -230,6 +229,7 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter DstScalarPerVector, ThreadTransferSrcResetCoordinateAfterRunFlags, ThreadTransferDstResetCoordinateAfterRunFlags, + IndexType, ScatterDim, OutputScatter, ScatterWeightIdx, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp index 2abf1d5a10..9c44bda5ca 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -59,7 +59,8 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator const std::array& input_right_pads, const AElementwiseOperation& a_element_op, const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op) = 0; + const CDEElementwiseOperation& cde_element_op, + const ck::index_t split_k = 1) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 00518b369f..72c011bfb2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -153,7 +153,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_as_grid + a_batch_offset, p_bs_grid + b_batch_offset, p_ds_grid_grp, @@ -439,7 +439,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = ck::conditional_t, ADataType>; using GemmBDataType = ck::conditional_t, BDataType>; -#define GridwiseGemmTemplateParameters \ +#define GridwiseGemmMultiABDTemplateParameters \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -454,11 +454,26 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched // Use appropriate gridwise gemm - using GridwiseGemm = - ck::conditional_t, - GridwiseGemmMultipleD_xdl_cshuffle>; + using GridwiseGemm = ck::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. using APointers = ck::conditional_t&, const void*>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index d53fbca4ea..fc1a2b995a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,19 +80,20 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -556,7 +557,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index 25a9d7f96d..0cd1d84a43 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -88,19 +88,20 @@ __global__ void __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - ck::Tuple<>{}, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ck::Tuple<>{}, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + ck::Tuple<>{}, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ck::Tuple<>{}, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -344,7 +345,6 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index 630f143260..12085edaae 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -107,19 +107,20 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -336,7 +337,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index 30ae72a63e..de7d67f08b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 2662e5c360..bae5c6019d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp index 963f0edd08..7d9555dc82 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp @@ -224,12 +224,20 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale PermuteA, PermuteB>; + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + static constexpr index_t BPackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; else return 1; }(); + struct ComputePtrOffsetOfStridedBatch { ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, @@ -352,10 +360,10 @@ struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); - auto size_a_buffer = - a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); - auto size_b_buffer = - b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; ck::utility::RotatingMemWrapper rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index f0f89f1d1b..77974f84ae 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -56,19 +56,20 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -324,7 +325,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 0b73317c5e..d4f89b3e09 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index 13eb23574f..a8eb73d730 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 28778d825b..6eb9281d30 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,6 +8,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 7fa231d4f4..5fad21f521 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 3be7313d2b..c7aa54f1d9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 8aa20f7ad4..68ec8187a4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef DEVICE_CONV3D_FWD_XDL_HPP #define DEVICE_CONV3D_FWD_XDL_HPP @@ -10,6 +10,7 @@ #include "device.hpp" #include "device_conv_fwd.hpp" #include "common_header.hpp" +#include "ck/utility/env.hpp" #include "tensor_layout.hpp" #include "convolution_forward_specialization.hpp" #include "tensor_descriptor.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 1edae33be3..ddabd61c3d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index de8f35a640..2881036bee 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index eb0fb55f5d..7faee161c1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 3fae3a3765..6c4195e75d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,19 +57,20 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -257,7 +258,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// @brief \"Universal\" GEMM operation with SplitK support. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations applied to the A, B, and C tensors, respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam CDataType C tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // TODO: Implement + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"<, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + /// @brief Helper structure responsible for kernel invocation. /// /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU @@ -278,10 +292,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 58a182b924..faa235be50 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -130,6 +130,20 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + int GetPreShuffleParameters() override { return NPerXDL; } // Invoker @@ -168,10 +182,10 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp index 044350d11c..456e5e90d1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -139,6 +139,20 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + // Invoker struct Invoker : public BaseInvoker { @@ -174,10 +188,10 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale rotating_mem( arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index 34df9a1d7b..2c34be9007 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -22,6 +22,7 @@ namespace ck { namespace tensor_operation { namespace device { +// clang-format off /** * \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types * @@ -31,8 +32,8 @@ namespace device { * Assumptions: * - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification * - Each scale applies to ScaleBlockSize elements in K direction - * - A scale matrix is row-major - * - B scale matrix is column-major + * - A scale matrix is a row-major + * - B scale matrix is a column-major * - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the * exponent will be interpreted as conventional biased Float32 exponent (E8M0) * @@ -72,10 +73,10 @@ namespace device { * for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){ * for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){ * for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){ - * // MFMA accumulation for multirate instructions - * for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPack){ - * for(int k_mfma = k_pack; k_mfma < k_pack + KPack; k_mfma += mfma.k_per_blk){ - * // MFMA instruction + * // MFMA accumulation + * for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPerXdlops){ + * // MFMA instruction + * for(int k_mfma = k_pack; k_mfma < k_pack + KPerXdlops; k_mfma += mfma.k_per_blk){ * for(int m = mw; m < mw + MPerXDL; m++){ * for(int n = nw; n < nw + NPerXDL; n++){ * for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){ @@ -96,6 +97,7 @@ namespace device { * \endcode * */ +// clang-format on template || is_same_v || - is_same_v || is_same_v || - is_same_v)&&(is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v), + static_assert(is_scale_mfma_data_type() && is_scale_mfma_data_type(), "Only microscaling formats are supported for ADataType and BDataType"); static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported"); + static_assert(is_same_v && is_same_v, + "ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType"); + return true; } static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_supported()) + if constexpr(!IsValidCompilationParameter()) + { + return false; + } + + if(ck::get_device_name() != "gfx950") { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp index fd6f3b65f2..213501468a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp index c2a27ebbdb..7315fe75a3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 1cf58fec25..884175eaca 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -64,7 +64,7 @@ __global__ void group_id = index_t((left + right) / 2); } - GridwiseGemm::template Run( + GridwiseGemm::template Run( contraction_arg_ptr[group_id].p_a_grid_, contraction_arg_ptr[group_id].p_b_grid_, contraction_arg_ptr[group_id].p_ds_grid_, @@ -368,7 +368,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 359711e5c4..5e41c96dfc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -227,7 +227,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) : p_a_grid_{static_cast(p_a)}, p_b_grid_{static_cast(p_b)}, p_ds_grid_{}, @@ -240,7 +241,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads} + input_right_pads_{input_right_pads}, + k_batch_{split_k} { // populate Ds pointer static_for<0, NumDTensor, 1>{}([&](auto i) { @@ -445,6 +447,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle std::array conv_filter_strides_; std::array input_left_pads_; std::array input_right_pads_; + + const index_t k_batch_; }; // Invoker @@ -534,6 +538,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { + if(arg.k_batch_ != 1) + { + return false; + } + // check device if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) { @@ -691,7 +700,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) { return Argument{p_a, p_b, @@ -711,7 +721,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle input_right_pads, a_element_op, b_element_op, - cde_element_op}; + cde_element_op, + split_k}; } static auto MakeInvoker() { return Invoker{}; } @@ -737,7 +748,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) override + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override { return std::make_unique(p_a, p_b, @@ -757,7 +769,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle input_right_pads, a_element_op, b_element_op, - cde_element_op); + cde_element_op, + split_k); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 770e531e44..f18ce40fc5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -8,18 +8,21 @@ #include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" #include "ck/host_utility/io.hpp" namespace ck { @@ -67,7 +70,8 @@ template + bool HasMainKBlockLoop, + InMemoryDataOperationEnum OutElementOp> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -88,12 +92,14 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n) + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); @@ -119,19 +125,22 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map, + KBatch, + k_idx); #else ignore = p_a_grid; ignore = p_b_grid; @@ -232,7 +241,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; + static constexpr bool IsSplitKSupported = + (CDEBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) && + std::is_same_v, element_wise::PassThrough>; // TODO: Add support for different A and B data types. using ABDataType = ADataType; @@ -314,53 +327,29 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); } - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ABDataType, - ABDataType, - AComputeType, - AccDataType, - CShuffleDataType, - DsDataType, - EDataType, - AElementwiseOp, - BElementwiseOp, - CDEElementwiseOp, - InMemoryDataOperationEnum::Set, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched, - PipelineVersion::v1, - BComputeType>; +// GridwiseGemm +#define GridwiseGemmMultiDTemplateParams \ + ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ + MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; + + template + static auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) + { + return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + } template static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) @@ -389,15 +378,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - DsGridDesc_M_N{})); + decltype(GridwiseGemmMultipleD_xdl_cshuffle:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{})); using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - EGridDesc_M_N{})); + decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); // block-to-e-tile map - using Block2ETileMap = - remove_cvref_t; + using Block2ETileMap = remove_cvref_t< + decltype(GridwiseGemmMultipleD_xdl_cshuffle< + GridwiseGemmMultiDTemplateParams>::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>; using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt; using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -510,7 +499,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) + const CDEElementwiseOp& cde_element_op, + ck::index_t split_k = 1) : p_a_grid_{static_cast(p_a)}, p_b_grid_{static_cast(p_b)}, p_ds_grid_{}, @@ -524,7 +514,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads} + input_right_pads_{input_right_pads}, + k_batch_{split_k} { std::array a_g_n_k_wos_strides_transposed = conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, @@ -625,7 +616,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 conv_filter_dilations, input_left_pads, input_right_pads, - tildes}; + tildes, + k_batch_}; conv_N_per_block_ = conv_to_gemm_transform_.N_; @@ -700,15 +692,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, - block_2_etile_map)) + block_2_etile_map, + k_batch_)) { ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm:: MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n)); e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n)); } } @@ -843,7 +837,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // pointers const ADataType* p_a_grid_; const BDataType* p_b_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; + typename GridwiseGemmMultipleD_xdl_cshuffle::DsGridPointer + p_ds_grid_; EDataType* p_e_grid_; // tensor descriptor for problem definition @@ -890,6 +885,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::array input_left_pads_; std::array input_right_pads_; + const index_t k_batch_; index_t num_workgroups_per_Conv_N_; }; @@ -898,12 +894,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { using Argument = DeviceOp::Argument; - float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + template + float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; const index_t gdy = arg.num_group_; - const index_t gdz = arg.num_workgroups_per_Conv_N_; + const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_; const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; @@ -932,7 +929,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.b_grid_desc_n_k_container_[i], arg.ds_grid_desc_m_n_container_[i], arg.e_grid_desc_m_n_container_[i], - arg.block_2_etile_map_container_[i])) + arg.block_2_etile_map_container_[i], + arg.k_batch_)) { throw std::runtime_error("wrong! device_op has invalid setting"); } @@ -960,7 +958,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 Block2ETileMap, ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch, - has_main_loop>; + has_main_loop, + ElementOp>; return launch_and_time_kernel( stream_config, @@ -981,10 +980,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], arg.block_2_etile_map_container_[i], arg.compute_ptr_offset_of_batch_, - arg.compute_ptr_offset_of_n_); + arg.compute_ptr_offset_of_n_, + arg.k_batch_); }; - if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, arg.k_batch_)) { ave_time += launch_kernel(integral_constant{}); } @@ -1083,7 +1083,19 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static_cast(arg.compute_ptr_offset_of_n_.BatchStrideA_)}, std::array{0}); } - ave_time += RunGemm(arg, stream_config); + if(arg.k_batch_ > 1) + { + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm(arg, stream_config); + } + } + else + { + ave_time += RunMultiDGemm(arg, stream_config); + } + // Transpose from NHWGC to NGCHW if constexpr(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) @@ -1147,6 +1159,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return false; } + if(!is_bf16_atomic_supported() && std::is_same_v && + arg.k_batch_ > 1) + { + return false; + } + + if constexpr(!IsSplitKSupported) + { + if(arg.k_batch_ != 1) + { + return false; + } + } + const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; @@ -1257,7 +1283,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.b_grid_desc_n_k_container_[i], arg.ds_grid_desc_m_n_container_[i], arg.e_grid_desc_m_n_container_[i], - arg.block_2_etile_map_container_[i])) + arg.block_2_etile_map_container_[i], + arg.k_batch_)) { return false; } @@ -1334,7 +1361,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) { return Argument{p_a, p_b, @@ -1354,7 +1382,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 input_right_pads, a_element_op, b_element_op, - cde_element_op}; + cde_element_op, + split_k}; } static auto MakeInvoker() { return Invoker{}; } @@ -1380,7 +1409,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const std::array& input_right_pads, const AElementwiseOp& a_element_op, const BElementwiseOp& b_element_op, - const CDEElementwiseOp& cde_element_op) override + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override { return std::make_unique(p_a, p_b, @@ -1400,7 +1430,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 input_right_pads, a_element_op, b_element_op, - cde_element_op); + cde_element_op, + split_k); } std::unique_ptr MakeInvokerPointer() override diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index 4d730b1f37..c7d95254c5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -8,7 +8,7 @@ #include #include "ck/utility/common_header.hpp" - +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -19,7 +19,7 @@ #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index f40b238c8a..c904b4e7d5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -8,6 +8,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index f6ea23a1e7..dd5b97096d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -17,7 +17,7 @@ #include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" @@ -81,6 +81,11 @@ __global__ void k_idx); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; #endif // end of if (defined(__gfx9__) } @@ -140,6 +145,11 @@ __global__ void k_idx); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; #endif // end of if (defined(__gfx9__) } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 272b832e11..27da1d91a3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -11,6 +11,7 @@ #include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -178,7 +179,7 @@ __global__ void const long_index_t a_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_as_grid + a_group_offset + a_n_offset, p_bs_grid + b_group_offset, p_ds_grid_grp, @@ -433,7 +434,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = std::conditional_t, ADataType>; using GemmBDataType = std::conditional_t, BDataType>; -#define GridwiseGemmTemplateParameters \ +#define GridwiseGemmMultiABDTemplateParameters \ GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -449,11 +450,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType // Use appropriate gridwise gemm - using GridwiseGemm = - std::conditional_t, - GridwiseGemmMultipleD_xdl_cshuffle>; + using GridwiseGemm = std::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. using APointers = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index b2f1dbfa5c..bebcd72ceb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -11,6 +11,7 @@ #include "ck/library/utility/numeric.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -278,9 +279,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr bool isMultiD = DsDataType::Size() > 0; static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; - // multi ABD not supported - static_assert(!isMultiABD, "Multi A, Mutli B and Multi D are not supported"); - static constexpr index_t NumATensor = GetNumABTensors(); static constexpr index_t NumBTensor = GetNumABTensors(); static constexpr index_t NumDTensor = DsDataType::Size(); @@ -1079,91 +1077,96 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; - - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) + if constexpr(!isMultiABD) { - const index_t a_grid_size = - arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( - arg.a_in_transpose_desc_); - const index_t b_grid_size = - arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( - arg.b_in_transpose_desc_); + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t a_grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t b_grid_size = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); - ADataType* p_a_out_grid = type_convert(arg.p_workspace_); - BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = + type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); - auto kernel_transpose = kernel_elementwise_dual, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Block2TileMapElementwise, - Block2TileMapElementwise, - element_wise::PassThrough>; + auto kernel_transpose = + kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; - avg_time += launch_and_time_kernel(stream_config, - kernel_transpose, - dim3(a_grid_size + b_grid_size), - dim3(ElementwiseBlocksize), - 0, - make_tuple(arg.a_in_transpose_desc_), - make_tuple(arg.b_in_transpose_desc_), - make_tuple(arg.a_out_transpose_desc_), - make_tuple(arg.b_out_transpose_desc_), - make_tuple(arg.p_a_grid_), - make_tuple(arg.p_b_grid_), - make_tuple(p_a_out_grid), - make_tuple(p_b_out_grid), - arg.elementwise_block_2_ctile_map_transpose_a_, - arg.elementwise_block_2_ctile_map_transpose_b_, - element_wise::PassThrough{}, - a_grid_size); + avg_time += + launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(a_grid_size + b_grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size); + } + + avg_time += RunGemm(arg, stream_config); + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + EDataType* p_e_out_grid = arg.p_e_grid_; + + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += + launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } } - - avg_time += RunGemm(arg, stream_config); - - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) - { - const index_t grid_size = - arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( - arg.e_in_transpose_desc_); - - const EDataType* p_e_in_grid = - type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / - sizeof(EDataType); - - EDataType* p_e_out_grid = arg.p_e_grid_; - - auto kernel_transpose = kernel_elementwise, - ck::Tuple, - ck::Tuple, - ck::Tuple, - Block2TileMapElementwise, - element_wise::PassThrough>; - - avg_time += launch_and_time_kernel(stream_config, - kernel_transpose, - dim3(grid_size), - dim3(ElementwiseBlocksize), - 0, - make_tuple(arg.e_in_transpose_desc_), - make_tuple(arg.e_out_transpose_desc_), - make_tuple(p_e_in_grid), - make_tuple(p_e_out_grid), - arg.elementwise_block_2_ctile_map_transpose_e_, - element_wise::PassThrough{}); - } - return avg_time; } @@ -1181,6 +1184,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const index_t G = arg.b_g_k_c_xs_lengths_[I0]; const index_t K = arg.b_g_k_c_xs_lengths_[I1]; const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + // Move this to runtime check to align Conv instances + // with Conv Multiple D instances + if constexpr(isMultiABD) + { + return false; + } // check device if(get_device_name() == "gfx908") diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index b2903121b1..94a4e0da4c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -89,7 +89,7 @@ __global__ void group_id = index_t((left + right) / 2); } - GridwiseGemm::template Run( + GridwiseGemm::template Run( gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, Tuple<>{}, @@ -192,7 +192,6 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t MaxGemmsNum = 32; - static_assert(NumDTensor == 0, "MultiD not supported."); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -351,16 +350,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor #define GridwiseGemmTemplateParameters \ ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ - KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ - ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ - ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ - ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ - ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ - BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ - BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ - BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ - CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ AComputeDataType @@ -440,89 +438,94 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { - // Perform grouped gemm, generate array of tranformer for convolution - Array conv_to_gemm_transformer_arr; - Array a_grid_ptrs; - Array c_grid_ptrs; - - ck::tie(conv_to_gemm_transformer_arr, - a_grid_ptrs, - c_grid_ptrs, - gemms_count_, - is_split_valid_) = - GenerateConvToGemmTransforms( - ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, - a_g_n_c_wis_strides_, - b_g_k_c_xs_lengths_, - b_g_k_c_xs_strides_, - e_g_n_k_wos_lengths_, - e_g_n_k_wos_strides_, - conv_filter_strides_, - conv_filter_dilations_, - input_left_pads_, - input_right_pads_}, - static_cast(p_a), - static_cast(p_e)); - - grid_size_ = 0; - valid_gemms_count_ = 0; - - if(is_split_valid_) + if constexpr(NumDTensor == 0) { - // Create GemmArg for each gemm(conv) - for(index_t i = 0; i < gemms_count_; i++) + // Perform grouped gemm, generate array of tranformer for convolution + Array conv_to_gemm_transformer_arr; + Array a_grid_ptrs; + Array c_grid_ptrs; + + ck::tie(conv_to_gemm_transformer_arr, + a_grid_ptrs, + c_grid_ptrs, + gemms_count_, + is_split_valid_) = + GenerateConvToGemmTransforms( + ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + static_cast(p_a), + static_cast(p_e)); + + grid_size_ = 0; + valid_gemms_count_ = 0; + + if(is_split_valid_) { - const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K( - conv_to_gemm_transformer_arr[i])}; - const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K( - conv_to_gemm_transformer_arr[i])}; - const auto e_grid_desc_m_n = - DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_arr[i]); - - const auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); - - const index_t grid_size_grp = - block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); - - const index_t BlockStart = grid_size_; - const index_t BlockEnd = grid_size_ + grid_size_grp; - - grid_size_ += grid_size_grp; - - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - Tuple<>{}, - e_grid_desc_m_n, - block_2_etile_map)) + // Create GemmArg for each gemm(conv) + for(index_t i = 0; i < gemms_count_; i++) { + const AGridDesc_M_K a_grid_desc_m_k{ + DeviceOp::MakeAGridDescriptor_M_K( + conv_to_gemm_transformer_arr[i])}; + const BGridDesc_N_K b_grid_desc_n_k{ + DeviceOp::MakeBGridDescriptor_N_K( + conv_to_gemm_transformer_arr[i])}; + const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N( + conv_to_gemm_transformer_arr[i]); - gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ - a_grid_ptrs[i], - static_cast(p_b), - c_grid_ptrs[i], - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n), - block_2_etile_map, - BlockStart, - BlockEnd}; + const auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); - valid_gemms_count_++; + const index_t grid_size_grp = + block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + Tuple<>{}, + e_grid_desc_m_n, + block_2_etile_map)) + { + + gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ + a_grid_ptrs[i], + static_cast(p_b), + c_grid_ptrs[i], + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd}; + + valid_gemms_count_++; + } } + // N is the same for all convs + conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); } - // N is the same for all convs - conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); + + // Strides for G and N remain the same + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; } - - // Strides for G and N remain the same - compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; - - compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; - compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; } void Print() const @@ -578,55 +581,63 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor { float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if(stream_config.log_level_ > 0) + if constexpr(NumDTensor == 0) { - arg.Print(); - } + if(stream_config.log_level_ > 0) + { + arg.Print(); + } - const index_t num_workgroups_per_Conv_N = - arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; - const index_t gdx = arg.grid_size_; - const index_t gdy = arg.num_group_; - const index_t gdz = num_workgroups_per_Conv_N; + const index_t gdx = arg.grid_size_; + const index_t gdy = arg.num_group_; + const index_t gdz = num_workgroups_per_Conv_N; - // K is constant for all gemms - const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * - arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); + // K is constant for all gemms + const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< - GridwiseGemm, - MaxGemmsNum, - GemmArgs, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - has_main_loop>; + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + const auto kernel = + kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< + GridwiseGemm, + MaxGemmsNum, + GemmArgs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; - return launch_and_time_kernel(stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg.gemm_desc_kernel_args_, - arg.gemms_count_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_); - }; + return launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.gemm_desc_kernel_args_, + arg.gemms_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + }; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } } else { - return launch_kernel(integral_constant{}); + return 0.f; } } @@ -643,6 +654,12 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor const long_index_t K = arg.b_g_k_c_xs_lengths_[I1]; const long_index_t C = arg.b_g_k_c_xs_lengths_[I2]; + // Move this to runtime check to align Conv instances + // with Conv Multiple D instances + if constexpr(NumDTensor != 0) + { + return false; + } // Check if all descs are valid if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_)) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index c148d7dbb7..10d8a4a44d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -8,6 +8,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -607,6 +608,9 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemmSetWorkSpacePointer(p_arg, p_dev_kernel_args); } - void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const { Argument* pArg_ = dynamic_cast(p_arg); if(!pArg_) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index d692aa05ce..18872e38ea 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -8,6 +8,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/hip_check_error.hpp" diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 2a6406aac3..cbee4e09f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -8,6 +8,7 @@ #include #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -64,7 +65,7 @@ __global__ void group_id = index_t((left + right) / 2); } - GridwiseGemm::template Run( + GridwiseGemm::template Run( gemm_desc_ptr[group_id].a_ptr_, gemm_desc_ptr[group_id].b_ptr_, gemm_desc_ptr[group_id].ds_ptr_, @@ -241,7 +242,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(p_arg); if(!pArg_) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 03431d7156..01f52881f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -7,6 +7,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/hip_check_error.hpp" @@ -423,6 +424,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitKSetWorkSpacePointer(p_arg, p_dev_kernel_args); } - void SetHostKernelArgs(BaseArgument* p_arg, void* p_host_kernel_args) const + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + /// + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const { Argument* pArg_ = dynamic_cast(p_arg); if(!pArg_) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 950fe0236d..08d177035e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -65,8 +65,12 @@ template , pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + int GetPreShuffleParameters() override { return NPerXDL; } // Invoker @@ -179,10 +202,10 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -265,7 +287,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -280,7 +301,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -290,7 +310,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } @@ -310,7 +329,6 @@ struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle; RunKernel(kernel); } diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 530876650e..0e58d5acb4 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -404,6 +404,14 @@ struct AddRelu y = a > type_convert(0.0f) ? a : type_convert(0.0f); }; + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float a = type_convert(x0) + type_convert(x1); + y = a > type_convert(0.0f) ? a : type_convert(0.0f); + }; + template <> __host__ __device__ constexpr void operator()(int& y, const int& x0, const int8_t& x1) const diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index f602e36e73..672998d811 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -357,6 +357,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(float& y, const int32_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 64fad1ca48..311545aad6 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -1438,6 +1438,7 @@ struct BlockToCTileMap_GemmStreamK_v2 __host__ __device__ BlockToCTileMap_GemmStreamK_v2( uint32_t m, uint32_t n, uint32_t k, uint32_t grid_size = 1, uint32_t streamk_sel = 1) { + // total output tiles uint32_t num_tiles = math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); @@ -1445,6 +1446,9 @@ struct BlockToCTileMap_GemmStreamK_v2 uint32_t dp_tiles, dp_num_blocks, sk_total_iters; + // Ensure grid_size is at least 1 to avoid division by zero + grid_size = math::max(grid_size, 1u); + // default to regular DP GEMM if sk blocks == 0 if(streamk_sel == 0) { @@ -1460,31 +1464,45 @@ struct BlockToCTileMap_GemmStreamK_v2 // 2-tile sk + DP GEMM else { - // check if there's enough work for DP+ stream-k bool bigEnough = num_tiles > grid_size; - // select between stream-k strategies + + // Select between stream-k strategies + // Add safety checks to prevent zero or negative values uint32_t sk_tiles = 0; if(streamk_sel == 1) // 1 tile stream-k { sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 + sk_tiles = math::max(sk_tiles, 1u); } else if(streamk_sel == 2) // 2-tile stream-k { sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); } else if(streamk_sel == 3) // 3-tile stream-k { sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); } else if(streamk_sel == 4) // 4-tile stream-k { sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); } + sk_num_blocks = sk_tiles; - // remaining tiles are DP tiles + // Remaining tiles are DP tiles dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0; sk_total_iters = k_iters_per_tile.get() * sk_tiles; @@ -1500,24 +1518,51 @@ struct BlockToCTileMap_GemmStreamK_v2 // => sk_blocks * m + b = sk_total_iters // => b = sk_total_iters - m * sk_blocks // NOTE: big could be zero - uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; - sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; - k_iters_per_big_block = k_iters_per_sk_block + 1; + + // Add safety check for sk_num_blocks to prevent division by zero + if(sk_num_blocks > 0) + { + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + } + else + { + // Fallback to default GEMM if no stream-k blocks + sk_num_blocks = 0; + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + dp_tiles = num_tiles; + dp_num_blocks = num_tiles; + dp_start_block_idx = 0; + sk_total_iters = 0; + } dp_num_blocks = dp_tiles; dp_start_block_idx = sk_num_blocks; } n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); - // using multiple blocks for parallel reduction + // Using multiple blocks for parallel reduction reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) { - uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); - uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); - equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); - equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + // Add additional safety checks + if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = + math::lcm(math::max(k_iters_per_big_block - 1, 1u), k_iters_per_tile.get()); + equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + else + { + // Default safe values + equiv_tiles_big = MDiv(1); + equiv_tiles_little = MDiv(1); + } } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index d54a00eaa2..a3301dd932 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -39,7 +39,6 @@ template ( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -550,6 +554,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return; } + const index_t num_k_per_block = + __builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch); + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -591,7 +598,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), + make_multi_index(num_k_per_block * k_idx, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -622,7 +629,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), + make_multi_index(num_k_per_block * k_idx, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -688,7 +695,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); + (KPerBlock * k_batch)); gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, @@ -943,6 +950,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } template (p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); } template (p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp new file mode 100644 index 0000000000..4dfa472103 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -0,0 +1,1725 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +/// @brief \"Universal\" GEMM kernel with SplitK support. +/// +/// @par Overview +/// This GEMM kernel is carrying out following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations that could be applied on each tensor respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam CDataType C tensor data type. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1Value The vector load size from global memory for A tensor. +/// @tparam BK1Value The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam AThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With universal GEMM +/// there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct GridwiseGemm_wmma_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + WmmaSelector::selected_wmma + .k_per_wmma); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_N_K1 -> K0_MNRepeat_MNWaves_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + static_assert(!PermuteA, "PermuteA is not supported"); + + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + + return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // TODO: Investigate why this path is not used in the original + // gridwise_gemm_xdl_cshuffle_v3.hpp +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerWmma; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerWmma; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + ComputeTypeB, + AccDataType, + decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), + decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + KPack>())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NPerWmma * NRepeat)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm_pipeline + .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm_pipeline + .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 1, 2, 6>{}, + Sequence<>{}, + Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor + .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor + .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp similarity index 99% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 4f5fedcd83..d37b3cd38e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index cc8ae1806a..e5e32a8535 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 9f6d85dd78..29150c0688 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index ffa01efe17..a22fc06a50 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index d53e0c85a7..4f3250162e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index b805f600d5..ac3e821340 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 238ab14606..c0d9464136 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -167,11 +167,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle using mfma_selector = MfmaSelector; static constexpr index_t KPack = math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; static constexpr index_t KLane = mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); - static constexpr index_t KRepeat = KPerBlock / KLane / KPack; - static constexpr index_t NLane = NPerXdl; - static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr index_t KPackPerGroup = KPack / KGroup; + static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup; + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; static constexpr auto MakeDsGridPointer() { @@ -209,7 +211,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle } __host__ __device__ static auto CalculateBK0Shuffled(index_t K) { - return math::integer_divide_ceil(K, KLane * KPack); + return math::integer_divide_ceil(K, KLane * KPackPerGroup); } __host__ __device__ static auto CalculateKPadded(index_t K) @@ -351,7 +353,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) { - constexpr index_t NkSwizzleNumber = Number{}; + constexpr index_t NkSwizzleNumber = Number{}; return make_naive_tensor_descriptor( make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); @@ -1228,7 +1230,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPackPerGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds @@ -1668,7 +1670,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, 0, - KPack * (get_thread_local_1d_id() % warpSize))); + KPackPerGroup * (get_thread_local_1d_id() % warpSize))); // LDS allocation for A and B: be careful of alignment // Cast after lds diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index 715fcbcfef..f877912329 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -13,6 +13,7 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" namespace ck { @@ -34,7 +35,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -65,7 +66,7 @@ __global__ void // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -158,16 +159,22 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static constexpr auto AK1Number = Number{}; static constexpr auto BK1Number = Number{}; - static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); - static constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) - ? true - : false; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + + //> KPack is at least the k_per_blk of selected mfma + // + // Should be a multiple of k_per_blk. + // TODO: Move this to blockwise pipeline base static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; @@ -790,12 +797,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // kfold and mpair dimension is not always required. // more dimension in merge_transform increase the difficulty of generating immarg offset // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; + constexpr auto WaveSize = 64; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / MPerXdl; + constexpr auto KThreadRead = WaveSize / MPerXdl; constexpr auto K0PerThreadRead = AK0Number / KThreadRead; constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) @@ -922,12 +930,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 } else // RowMajor B { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; + constexpr auto WaveSize = 64; + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / NPerXdl; + constexpr auto KThreadRead = WaveSize / NPerXdl; constexpr auto K0PerThreadRead = BK0Number / KThreadRead; constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) @@ -1087,10 +1096,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_assert(KPerBlock % ScaleBlockSize == 0, "KPerBlock should be multiple of ScaleBlockSize"); - static_assert(KPerBlock / ScaleBlockSize == BlockwiseGemmPipe::KRepeat, - "Single call to xdlops_gemm::Run should process exactly ScaleBlockSize " - "elements in k dimension"); - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || @@ -1475,61 +1480,63 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); - static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; - static constexpr auto KPerThread = KPerBlock / K0PerXdlops; - - // NXdlPerWave == NRepeat - // MXdlPerWave == MRepeat - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - - // Initial thread mapping for MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MWaves=NWaves=2 + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] - auto a_thread_offset_m = - MPerXdl * ((get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) / MWaves) + - mfma.selected_mfma.group_size * - ((get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / MPerXdl); - auto a_thread_offset_k = KPerThread * (get_thread_local_1d_id() % MPerXdl) / MPerXdl; + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] - auto b_thread_offset_n = - get_thread_local_1d_id() % NPerXdl + - (get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl; - auto b_thread_offset_k = KPerThread * (get_thread_local_1d_id() % NPerXdl) / NPerXdl; + // TODO: Document initial thread mapping for more combinations of parameters - auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< - AScaleDataType, - AScaleDataType, - decltype(a_scale_grid_desc_am_ak), // SrcDesc - decltype(BlockwiseGemmPipe::a_scale_thread_desc_group), // DstDesc - Sequence, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 0, // SrcVectorDim - 1, // SrcScalarPerVector - 1, // SrcScalarStrideInVector - true>(a_scale_grid_desc_am_ak, - make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, - a_thread_offset_k / ScaleBlockSize)); + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; - auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< - BScaleDataType, - BScaleDataType, - decltype(b_scale_grid_desc_bn_ak), - decltype(BlockwiseGemmPipe::b_scale_thread_desc), - Sequence<1, BlockwiseGemmPipe::KRepeat>, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - BlockwiseGemmPipe::KRepeat, // SrcScalarPerVector - 1, - false>(b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, - b_thread_offset_k / ScaleBlockSize)); + static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + mfma.selected_mfma.num_threads_per_blk; + + auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl; + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 1, // SrcVectorDim + 1, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k)); + + auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 1, // SrcVectorDim + 1, // SrcScalarPerVector + 1, + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 6ee279a3f1..256b495c6e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -1,9 +1,10 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 1924c27b2b..255fb8cff4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" @@ -26,11 +26,17 @@ namespace ck { // two lds chunks. // 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds // buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -44,7 +50,7 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm::template Run( + GridwiseGemm::template Run( karg.p_sorted_token_ids, karg.p_sorted_expert_ids, karg.p_max_token_id, @@ -66,7 +72,6 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -81,21 +86,20 @@ __global__ void auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - GridwiseGemm:: - template Run_2Lds( - karg.p_sorted_token_ids, - karg.p_sorted_expert_ids, - karg.p_max_token_id, - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_ds_grid, - karg.p_c_grid, - p_shared, - p_shared1, - karg, - karg.a_element_op, - karg.b_element_op, - karg.c_element_op); + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); #else ignore = karg; #endif // end of if (defined(__gfx9__)) @@ -147,7 +151,12 @@ template ) @@ -490,8 +500,8 @@ struct GridwiseMoeGemm } template - __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) { const auto c_grid_desc_mraw_nraw = [&]() { if constexpr(is_same::value) @@ -902,7 +912,8 @@ struct GridwiseMoeGemm NPerXdl, MXdlPerWave, NXdlPerWave, - KPack>())>; + KPack, + IsInputGemm>())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -1134,7 +1145,6 @@ struct GridwiseMoeGemm template __device__ static void Run(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, @@ -1195,6 +1205,7 @@ struct GridwiseMoeGemm return {blockIdx.x, blockIdx.y}; } }(); + const index_t block_n_id = block_mn.first; const index_t block_m_id = block_mn.second; const index_t token0 = @@ -1210,7 +1221,7 @@ struct GridwiseMoeGemm if(token_pos >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray gather_offsets; + StaticallyIndexedArray gather_offsets; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; index_t token_offset = fused_token & 0xffffff; @@ -1218,9 +1229,10 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = token_offset * problem.K; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); - const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); // N0, K0, Blocksize*KPack const index_t n_block_data_idx_on_grid = @@ -1231,7 +1243,6 @@ struct GridwiseMoeGemm const auto b_grid_buf = make_dynamic_buffer( p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); - // A matrix in LDS memory, dst of blockwise copy constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); @@ -1261,6 +1272,7 @@ struct GridwiseMoeGemm 1, AThreadTransferSrcResetCoordinateAfterRun, true, + IndexType, 1, BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1303,24 +1315,74 @@ struct GridwiseMoeGemm static_assert(std::is_default_constructible_v); auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bpreshuffled, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % warpSize))); + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } // shuffle C and write out { @@ -1348,6 +1410,185 @@ struct GridwiseMoeGemm constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + // mul scales + const float* p_sorted_weights_0 = p_ds_grid[I0]; + const float* p_scale_b = p_ds_grid[I1]; + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[token_offset] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = + scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * + topk_weights.AsType()[m4]; + } + } + }); + }); + }); + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + } + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -1445,17 +1686,8 @@ struct GridwiseMoeGemm const auto ds_grid_buf = generate_tuple( [&](auto i) { - using DDataType = remove_cvref_t>; - const DDataType* ptr_ = p_ds_grid[i]; - // hack logic here to support different kind of strides. todo fix it. - // ascale t, 1; bscale E, N, 1, move ptr to E - if(i.value == 1) - { - ptr_ += - expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1); - } return make_dynamic_buffer( - ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); }, Number{}); @@ -1492,7 +1724,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix + constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1518,7 +1750,8 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights scatter_weight_idx // ScatterWeightIdx: ascale @@ -1530,7 +1763,6 @@ struct GridwiseMoeGemm auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, @@ -1560,32 +1792,21 @@ struct GridwiseMoeGemm constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray scatter_offsets; - StaticallyIndexedArray scatter_weights; //= for topk + StaticallyIndexedArray scatter_offsets; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; - float weight = token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] - : 0.0; + IndexType token_offset = fused_token & 0xffffff; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else - { - const float* p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } - scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; + scatter_offsets(m0) = static_cast(token_offset) * problem.N; }); block_sync_lds(); @@ -1593,7 +1814,7 @@ struct GridwiseMoeGemm // each thread write its data from VGPR to LDS c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, + c_thread_buf_fp32, c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, c_shuffle_block_buf); @@ -1606,8 +1827,7 @@ struct GridwiseMoeGemm c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { @@ -1632,7 +1852,6 @@ struct GridwiseMoeGemm template __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, const index_t* p_sorted_expert_ids, @@ -1709,7 +1928,7 @@ struct GridwiseMoeGemm if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens) return; - StaticallyIndexedArray + StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; static_for<0, AMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[token_pos + m0]; @@ -1718,7 +1937,7 @@ struct GridwiseMoeGemm { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - gather_offsets(m0) = token_offset * problem.K; + gather_offsets(m0) = static_cast(token_offset) * problem.K; }); const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); @@ -1761,6 +1980,7 @@ struct GridwiseMoeGemm 1, AThreadTransferSrcResetCoordinateAfterRun, true, + IndexType, 1, 2>(a_grid_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -1955,11 +2175,12 @@ struct GridwiseMoeGemm const DDataType* ptr_ = p_ds_grid[i]; // hack logic here to support different kind of strides. todo fix it. // ascale t, 1; bscale E, N, 1, move ptr to E - if(i.value == 1) - { - ptr_ += - expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : 1); - } + // if(i.value == 1) + // { + // ptr_ += + // expert_id * (problem.StrideDs[1] ? problem.StrideDs[1] * problem.N : + // 1); + // } return make_dynamic_buffer( ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); }, @@ -1998,7 +2219,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix + constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -2024,7 +2245,8 @@ struct GridwiseMoeGemm Sequence, uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, 1, // ScatterDim true, // OutputScatter: false, only use scatter weights scatter_weight_idx // ScatterWeightIdx: ascale @@ -2066,12 +2288,9 @@ struct GridwiseMoeGemm constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const float* p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS - StaticallyIndexedArray - scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - StaticallyIndexedArray scatter_weights; //= for topk + StaticallyIndexedArray scatter_offsets; auto dstidx = sfc_cde_block.GetIndex(access_id); const index_t c_token_pos = @@ -2079,20 +2298,11 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = token_offset < problem.NumTokens - ? p_sorted_weights_0[token_offset * problem.StrideDs[0]] - : 0.0; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - else - { - const float* p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } - scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; + scatter_offsets(m0) = static_cast(token_offset) * problem.N; }); block_sync_lds(); @@ -2113,8 +2323,7 @@ struct GridwiseMoeGemm c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), tie(c_grid_buf), - scatter_offsets, - scatter_weights); + scatter_offsets); if constexpr(access_id < num_access - 1) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index d2fffcc591..753589f303 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -211,8 +211,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 * @tparam SrcVectorDim The dimension along which vectorized access is performed in the source * tensor. * @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor. - * @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source - * tensor. + * @tparam SrcScalarStrideInVector Not used. * @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run * or rolled back one step in MoveSrcSliceWindow * @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index bb9a452761..bd6fe772e4 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -41,6 +41,7 @@ template struct ThreadwiseTensorSliceTransfer_v3r1_gather @@ -88,7 +89,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather const DstDesc& dst_desc, const Index& dst_slice_origin, const DstElementwiseOperation& dst_element_op, - const StaticallyIndexedArray& gather_offsets) + const StaticallyIndexedArray& gather_offsets) : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), src_element_op_(src_element_op), @@ -221,7 +222,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather auto gather_offset = gather_offsets_(ordered_src_access_idx[Number{}]); - const index_t ld_offset = src_coord_.GetOffset() + gather_offset; + const IndexType ld_offset = src_coord_.GetOffset() + gather_offset; src_oob_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, true); @@ -935,7 +936,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather DstCoord dst_coord_; const SrcElementwiseOperation src_element_op_; const DstElementwiseOperation dst_element_op_; - StaticallyIndexedArray gather_offsets_; + StaticallyIndexedArray gather_offsets_; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 6a1c195dc1..7cd0a0fc7f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -43,6 +43,7 @@ template typename DstResetCoordinateAfterRunFlags, // Sequence + typename IndexType, index_t ScatterDim = 1, bool OutputScatter = true, index_t ScatterWeightIdx = 3, @@ -153,7 +154,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, - StaticallyIndexedArray& scatter_weights, Number thread_scratch_id = Number{}) { // loop over space-filling curve @@ -172,31 +172,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter src_coords_[i]); oob_val = oob_val & is_src_valid; - if(i.value == ScatterWeightIdx) - { - static_assert(SrcScalarPerVectors{}[Number{}] == 1, - "scatter weight dim, should only one vec"); - constexpr auto iScatter = - SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); - static_for<0, SrcScalarPerVector, 1>{}([&](auto j) { - src_vectors(i).template AsType()(j) = - scatter_weights(Number{}); - }); - } - else if constexpr(SrcScalarPerVectors{}[i] == 1) - { - auto data_types = SrcDatas{}; - using DataType = remove_cvref_t; - const auto tmp = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - static_for<0, SrcScalarPerVector, 1>{}( - [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); - } - else - { - src_vectors(i).template AsType()(I0) = - src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - } + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); }); constexpr auto get_elem_op_vec_len = []() { @@ -412,7 +389,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, + StaticallyIndexedArray& scatter_offsets, Number thread_scratch_id = Number{}) { OOBCheck(thread_scratch_id); @@ -420,8 +397,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // loop over space-filling curve static_for<0, dst_num_access, 1>{}([&](auto iAccess) { - auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; - auto scatter_offset = 0; + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + IndexType scatter_offset = 0; if constexpr(OutputScatter) { constexpr auto iScatter = @@ -431,8 +408,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { using dst_vector_t = typename remove_cvref_t::type; - auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); + IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); + // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + // dst_coords_[i]); constexpr InMemoryDataOperationEnum DstInMemOp = static_cast(DstInMemOps::At(i.value)); dst_bufs(i).template Update( @@ -488,10 +467,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter const SrcBuffers& src_bufs, const DstDescs& dst_descs, DstBuffers dst_bufs, - StaticallyIndexedArray& scatter_offsets, - StaticallyIndexedArray& scatter_weights) + StaticallyIndexedArray& scatter_offsets) { - RunRead(src_descs, src_bufs, scatter_weights); + RunRead(src_descs, src_bufs); RunWrite(dst_descs, dst_bufs, scatter_offsets); } diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 1abae56be4..429df2413f 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,6 +22,10 @@ enum struct WmmaInstr wmma_f32_16x16x16_f16_gfx12, wmma_f32_16x16x16_bf16_gfx12, wmma_i32_16x16x16_iu8_gfx12, + wmma_f32_16x16x16_f8f8_gfx12, + wmma_f32_16x16x16_f8bf8_gfx12, + wmma_f32_16x16x16_bf8f8_gfx12, + wmma_f32_16x16x16_bf8bf8_gfx12, }; /* @@ -400,6 +404,146 @@ struct wmma_type +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_f8f8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + template + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12; + } + + template <> + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12; + } + + template <> + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12; + } + + template <> + constexpr auto GetWmma() + { + return WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12; + } + // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround static constexpr auto selected_wmma = wmma_type(), Number<32>{}>{}; @@ -612,14 +781,17 @@ struct WmmaGemm (is_same::value && is_same::value && is_same::value) || (is_same::value && is_same::value && - is_same::value) + is_same::value) || + ((is_same::value || is_same::value) && + (is_same::value || is_same::value) && + is_same::value) || #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || (is_same::value && is_same::value && - is_same::value) + (is_same::value && is_same::value && + is_same::value) || #endif - , + false, "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " - "(int8, int32) or (int4, int32)!"); + "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!"); static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) { if constexpr(!TransposeC) { diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index a638ca8608..06268f3cfb 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -845,15 +845,24 @@ struct mfma_type static constexpr bool is_k_reduction = true; // ??? // clang-format on - template + template __device__ void run(const FloatA& a, - const int32_t scale_a, + const ScaleA& scale_a, const FloatB& b, - const int32_t scale_b, + const ScaleB& scale_b, FloatC& reg_c) const { + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + intrin_mfma_scale_f32_32x32x64f8f6f4::Run( - a, scale_a, b, scale_b, reg_c); + a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c); } }; @@ -874,15 +883,24 @@ struct mfma_type static constexpr bool is_k_reduction = true; // ??? // clang-format on - template + template __device__ void run(const FloatA& a, - const int32_t scale_a, + const ScaleA& scale_a, const FloatB& b, - const int32_t scale_b, + const ScaleB& scale_b, FloatC& reg_c) const { + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + intrin_mfma_scale_f32_16x16x128f8f6f4::Run( - a, scale_a, b, scale_b, reg_c); + a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c); } }; @@ -890,14 +908,16 @@ template + bool is_single_rate_mfma = false, + bool is_scale_mfma = false> struct MfmaSelector { template + bool is_single_rate_mfma_ = false, + bool is_scale_mfma_ = false> static constexpr auto GetMfma(); template <> @@ -1103,12 +1123,48 @@ struct MfmaSelector return MfmaInstr::mfma_f32_32x32x16f8f8; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8f8; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + template <> constexpr auto GetMfma() { @@ -1145,8 +1201,12 @@ struct MfmaSelector return MfmaInstr::mfma_f32_16x16x32bf8f8; } - static constexpr auto selected_mfma = mfma_type< - GetMfma()>{}; + static constexpr auto selected_mfma = mfma_type()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -1194,7 +1254,8 @@ template + bool TransposeC = false, + bool is_scale_mfma = false> struct XdlopsGemm { static constexpr auto I0 = Number<0>{}; @@ -1225,7 +1286,7 @@ struct XdlopsGemm MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk"); } // XDL output supporting C = A * B @@ -1368,6 +1429,27 @@ struct XdlopsGemm }); } + template + __device__ void Run(const FloatA& p_a_wave, + const ScaleA& a_scale_thread, + const FloatB& p_b_wave, + const ScaleB& b_scale_thread, + FloatC& p_c_thread) const + { + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + mfma_instr.template run( + p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread); + } + else + { + mfma_instr.template run( + p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread); + } + }); + } + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; } __device__ static auto GetBlkIdx() @@ -1455,7 +1537,8 @@ struct XdlopsGemm KPack <= 4) || (is_same::value && KPack <= 8)) ? true - : false > {}; + : false, + is_scale_mfma > {}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp index 0ddfd0a7c8..a191c75099 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -187,7 +187,8 @@ struct TransformConvBwdDataToGemm_v1 WTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.WTilde_)}, ZDot_{static_cast(transform_conv_bwd_data_to_gemm_base.ZDot_)}, YDot_{static_cast(transform_conv_bwd_data_to_gemm_base.YDot_)}, - XDot_{static_cast(transform_conv_bwd_data_to_gemm_base.XDot_)} + XDot_{static_cast(transform_conv_bwd_data_to_gemm_base.XDot_)}, + batch_k_{transform_conv_bwd_data_to_gemm_base.batch_k_} { } @@ -203,7 +204,8 @@ struct TransformConvBwdDataToGemm_v1 const ConvSpatialDimsType& conv_filter_dilations, const ConvSpatialDimsType& input_left_pads, const ConvSpatialDimsType& input_right_pads, - const ConvSpatialDimsType& tildes) + const ConvSpatialDimsType& tildes, + const index_t batch_k = 1) : Hi_{c_g_n_c_wis_lengths[HIdx]}, Wi_{c_g_n_c_wis_lengths[WIdx]}, Ho_{a_g_n_k_wos_lengths[HIdx]}, @@ -231,7 +233,8 @@ struct TransformConvBwdDataToGemm_v1 InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]}, InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]}, IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]}, - IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]} + IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}, + batch_k_{batch_k} { static_assert(is_same_v> || is_same_v>); @@ -616,20 +619,22 @@ struct TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - const index_t AK0 = math::integer_divide_ceil(K_, AK1); + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock; // A: output tensor const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( out_grid_desc, make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), - make_unmerge_transform(make_tuple(AK0, AK1))), + make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{})); const auto out_gemmak0_gemmm_gemmak1_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( out_gemmak0_gemmmraw_gemmak1_grid_desc, - make_tuple(AK0, GemmMPerBlock, AK1), + make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1), Sequence{}); return out_gemmak0_gemmm_gemmak1_grid_desc; @@ -719,11 +724,15 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); - const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( out_gemmk_gemmm_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), make_pass_through_transform( out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -816,11 +825,15 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmMPerBlock), Sequence{}); - const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1; + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( out_gemmk_gemmm_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), make_pass_through_transform( out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -850,21 +863,23 @@ struct TransformConvBwdDataToGemm_v1 ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: Filter1x1Stride1Pad0) { - const index_t BK0 = math::integer_divide_ceil(K_, BK1); + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock; // B: weight tensor - const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = - transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)), - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K_, C_)), + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1)); const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, - make_tuple(BK0, GemmNPerBlock, BK1), + make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1), Sequence{}); return wei_gemmbk0_gemmn_gemmbk1_grid_desc; @@ -925,11 +940,15 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); - const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( wei_gemmk_gemmn_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), make_pass_through_transform( wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -1006,11 +1025,15 @@ struct TransformConvBwdDataToGemm_v1 make_tuple(GemmKPerBlock, GemmNPerBlock), Sequence{}); - const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1; + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( wei_gemmk_gemmn_padded_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), make_pass_through_transform( wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), @@ -1355,6 +1378,7 @@ struct TransformConvBwdDataToGemm_v1 IndexType ZTilde_, YTilde_, XTilde_; IndexType DTilde_, HTilde_, WTilde_; IndexType ZDot_, YDot_, XDot_; + index_t batch_k_; }; } // namespace tensor_operation diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 317f324e6d..62e3220b5a 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -581,7 +581,7 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type::typ tmp.template AsType()[i]); }); } -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__) else if constexpr(is_same::value) { vector_type tmp{src_thread_data}; diff --git a/include/ck/utility/amd_buffer_addressing_builtins.hpp b/include/ck/utility/amd_buffer_addressing_builtins.hpp index 19869906dc..296c1d44d7 100644 --- a/include/ck/utility/amd_buffer_addressing_builtins.hpp +++ b/include/ck/utility/amd_buffer_addressing_builtins.hpp @@ -80,7 +80,7 @@ __device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( int32x4_t rsrc, index_t voffset, index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32"); + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); // buffer atomic-add i32 __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( @@ -88,7 +88,7 @@ __device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( int32x4_t rsrc, index_t voffset, index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32"); + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); // buffer atomic-add fp32 __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( @@ -96,15 +96,15 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( int32x4_t rsrc, index_t voffset, index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32"); + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); // buffer atomic-add fp32 -__device__ double llvm_amdgcn_raw_buffer_atomic_max_fp64( - double vdata, - int32x4_t rsrc, // dst_wave_buffer_resource - int voffset, // dst_thread_addr_offset - int soffset, // dst_wave_addr_offset - int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32"); +__device__ double +llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, + int32x4_t rsrc, // dst_wave_buffer_resource + int voffset, // dst_thread_addr_offset + int soffset, // dst_wave_addr_offset + int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); // memory coherency bit for buffer store/load instruction // check ISA manual for each GFX target @@ -827,7 +827,7 @@ llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, index_t voffset, index_t soffset, index_t offset, - index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32"); + index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); #ifndef __HIPCC_RTC__ template diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 5c80c42d6c..d079639c6a 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -64,6 +64,9 @@ enum class ck_saturation_t namespace fp8_impl { typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2))); +typedef _Float16 half2_t __attribute__((ext_vector_type(2))); +typedef ushort ushortx2_t __attribute__((ext_vector_type(2))); +typedef short shortx2_t __attribute__((ext_vector_type(2))); typedef float float2_t __attribute__((ext_vector_type(2))); __host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a) @@ -270,7 +273,7 @@ static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v) } template -static __device__ float2_t cast_to_f32x2_from_f8x2(fp8x2_storage_t v) +static __device__ float2_t cast_to_f32_from_f8(fp8x2_storage_t v) { const auto i16val = bit_cast(v); @@ -458,6 +461,510 @@ __is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp) #endif } +#if defined(__gfx950__) +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + half2_t half_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + } + + val.i32val = + __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + half2_t half_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + } + + val.i32val = + __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + unsigned int i32val; + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +#else + ignore = rng; + + union + { + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + unsigned int i32val; + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + } + + val.half_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +#else + ignore = rng; + + union + { + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16( + i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16( + i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +#else + ignore = rng; + + union + { + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[1] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[1] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +} +#endif // defined(__gfx950__) + #if CK_FP8_CVT_FAST_PATH // The conversion function is from rocblas // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79 @@ -523,6 +1030,84 @@ static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = } return i8data; } + +template +static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0) +{ + if constexpr(stochastic_rounding) + { + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f32(v[0], rng), + cast_to_f8_from_f32(v[1], rng)}; + } + else + { + union + { + float fval; + unsigned int i32val; + unsigned char i8val[4]; + } val0, val1; + + val0.fval = v[0]; + val1.fval = v[1]; + + unsigned int ival = 0; + + if constexpr(saturate) + { + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0); + } + } + else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { // OCP type + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0); + } + } + else + { + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0); + } + } + } + + // RNE CVT + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false); + } + else + { + ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false); + } + + val0.i32val = ival; + + return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]}; + } +} #endif // CK_FP8_CVT_FAST_PATH // The conversion function is from rocblas @@ -797,6 +1382,7 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn * * \tparam interp interpretation of fp8 * \tparam sat saturation of fp8 + * \tparam stochastic_rounding switch between RNE and SR * \param f float number * \return fp8_storage_t */ @@ -882,6 +1468,47 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) #endif // CK_FP8_CVT_FAST_PATH } +/** + * \brief convert vector of 2 floats to vector of 2 @p fp8_storage_t + * + * \tparam interp interpretation of fp8 + * \tparam sat saturation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param f vector of 2 floats + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH +__device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f[0]); +#else + rng = prand_generator(reinterpret_cast(&f), f[0]); +#endif + } + return cast_to_f8_from_f32( + f, rng); +#else +#if CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ +#else +__host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ +#endif // CK_USE_OCP_FP8 + return fp8x2_storage_t{cvt_float_to_fp8(f[0]), + cvt_float_to_fp8(f[1])}; +#endif // CK_FP8_CVT_FAST_PATH +} + /** * \brief convert _Float16 to @p fp8_storage_t * @@ -900,87 +1527,168 @@ __host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) #endif { - return cvt_float_to_fp8(static_cast(x)); + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), x); +#else + rng = prand_generator(reinterpret_cast(&x), x); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_f16(x, rng); +#else + ignore = rng; + return cvt_float_to_fp8( + static_cast(x)); +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert vector of 2 _Float16 to vector of 2 @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x vector of 2 _Float16 + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x) +#else +__host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x) +#endif +{ + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_f16(x, rng); +#else + ignore = rng; + return cvt_float_to_fp8( + float2_t{static_cast(x[0]), static_cast(x[1])}); +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert bhalf_t to @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x bhalf_t value + * \return fp8_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x) +#else +__host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x) +#endif +{ + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), + static_cast(x)); +#else + rng = prand_generator(reinterpret_cast(&x), static_cast(x)); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_bf16(x, rng); +#else + ignore = rng; + return cvt_float_to_fp8( + bit_cast(uint32_t{x} << 16)); // convert value to float +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert vector of 2 bhalf_t to vector of 2 @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x vector of 2 bhalf_t + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) +#else +__host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) +#endif +{ +#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION + return cvt_float_to_fp8( + float2_t{bit_cast(uint32_t{x[0]} << 16), + bit_cast(uint32_t{x[1]} << 16)}); // convert values to float +#else // CK_WORKAROUND_BF16_TO_FP8_CONVERSION + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), + static_cast(x[0])); +#else + rng = prand_generator(reinterpret_cast(&x), + static_cast(x[0])); +#endif + } +#if defined(__gfx950__) + return cast_to_f8_from_bf16(x, rng); +#else + ignore = rng; + return cvt_float_to_fp8( + float2_t{bit_cast(uint32_t{x[0]} << 16), + bit_cast(uint32_t{x[1]} << 16)}); // convert values to float +#endif // defined(__gfx950__) + } +#endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION } } // namespace fp8_impl -// Declare a template function for fp8 conversion using RNE -template -__host__ __device__ constexpr Y f8_convert_rne(X x); - -// convert fp32 to fp8 with rounding to nearest even -template <> -inline __host__ __device__ f8_ocp_t f8_convert_rne(float x) -{ - return f8_ocp_t{ - fp8_impl::cvt_float_to_fp8(x)}; -} - -// convert fp32 to bf8 with rounding to nearest even -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_rne(float x) -{ - return bf8_ocp_t{ - fp8_impl::cvt_float_to_fp8(x)}; -} - -// convert _Float16 to fp8 with rounding to nearest even -template <> -inline __host__ __device__ f8_ocp_t f8_convert_rne(_Float16 x) -{ - return f8_ocp_t{ - fp8_impl::cvt_half_t_to_fp8(x)}; -} - -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_rne(_Float16 x) -{ - return bf8_ocp_t{ - fp8_impl::cvt_half_t_to_fp8( - x)}; -} - -// Declare a template function for fp8 conversion using RNE -template -__host__ __device__ constexpr Y f8_convert_sr(X x); - -// convert fp32 to fp8 with stochastic rounding -template <> -inline __host__ __device__ f8_ocp_t f8_convert_sr(float x) -{ - return f8_ocp_t{ - fp8_impl::cvt_float_to_fp8( - x)}; -} - -// convert fp32 to bf8 with stochastic rounding -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_sr(float x) -{ - return bf8_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; -} - -// convert _Float16 to fp8 with stochastic rounding -template <> -inline __host__ __device__ f8_ocp_t f8_convert_sr(_Float16 x) -{ - return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; -} - -// convert _Float16 to bf8 with stochastic rounding -template <> -inline __host__ __device__ bf8_ocp_t f8_convert_sr(_Float16 x) -{ - return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; -} - #if CK_USE_OCP_FP8 using f8_t = f8_ocp_t; using bf8_t = bf8_ocp_t; diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index aa519fb2be..e14c0d62a8 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #ifndef CK_AMD_WMMA_HPP #define CK_AMD_WMMA_HPP @@ -341,5 +341,101 @@ struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> } }; +// src: f8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: f8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck #endif diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 0d4611becc..66c4958e1d 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -508,6 +508,34 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; @@ -520,9 +548,9 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> { template __device__ static void Run(const f8x32_t& reg_a, - const int32_t scale_a, + const int32_t& scale_a, const f8x32_t& reg_b, - const int32_t scale_b, + const int32_t& scale_b, FloatC& reg_c) { #if defined(__gfx950__) @@ -532,12 +560,91 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, // OPSEL scale_a, 0, // OPSEL scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, + const int32_t scale_a, + const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); #else ignore = reg_a; ignore = scale_a; @@ -556,9 +663,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> { template __device__ static void Run(const f8x32_t& reg_a, - const int32_t scale_a, + const int32_t& scale_a, const f8x32_t& reg_b, - const int32_t scale_b, + const int32_t& scale_b, FloatC& reg_c) { #if defined(__gfx950__) @@ -568,7 +675,127 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> reg_a, reg_b, reg_c.template AsType()[Number<0>{}], - 0, // cbsz + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, + const int32_t scale_a, + const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} 0, // blgp 0, // OPSEL scale_a, @@ -616,6 +843,33 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp index 39407cb8f6..6c788fb41e 100644 --- a/include/ck/utility/blkgemmpipe_scheduler.hpp +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -48,6 +48,15 @@ enum struct TailNumber // prefetchstages Full, }; + +enum SchedulerGroup : uint32_t +{ + SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions + SCHED_GROUP_VMEM = 0x020, // Global memory operations + SCHED_GROUP_LDS_READ = 0x100, // LDS read operations + SCHED_GROUP_LDS_WRITE = 0x200 // LDS write operations +}; + template static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = f4x2_pk_t::type; + static constexpr index_t vector_size = 1; +}; + template <> struct scalar_type { diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 1a0ea27eab..1d80f196b5 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -24,7 +24,8 @@ template + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + typename IndexType = index_t> struct DynamicBuffer { using type = T; @@ -59,16 +60,16 @@ struct DynamicBuffer return BufferAddressSpace; } - __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; } + __host__ __device__ constexpr const T& operator[](IndexType i) const { return p_data_[i]; } - __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } + __host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; } template >::type, typename scalar_type>::type>::value || !is_native_type(), bool>::type = false> - __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const + __host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -79,7 +80,7 @@ struct DynamicBuffer "wrong! X should contain multiple T"); #if CK_USE_AMD_BUFFER_LOAD - bool constexpr use_amd_buffer_addressing = true; + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -140,7 +141,7 @@ struct DynamicBuffer typename enable_if>::type, typename scalar_type>::type>::value, bool>::type = false> - __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x) { if constexpr(Op == InMemoryDataOperationEnum::Set) { @@ -191,8 +192,8 @@ struct DynamicBuffer template __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf, - index_t src_offset, - index_t dst_offset, + IndexType src_offset, + IndexType dst_offset, bool is_valid_element) const { // Copy data from global to LDS memory using direct loads. @@ -214,7 +215,7 @@ struct DynamicBuffer typename scalar_type>::type>::value || !is_native_type(), bool>::type = false> - __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void Set(IndexType i, bool is_valid_element, const X& x) { // X contains multiple T constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; @@ -224,8 +225,8 @@ struct DynamicBuffer static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); -#if CK_USE_AMD_BUFFER_STORE - bool constexpr use_amd_buffer_addressing = true; +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -342,11 +343,12 @@ struct DynamicBuffer { if(is_valid_element) { -#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS +#if 0 X tmp = x; __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); #else + // if(i >= 2169041600) *c_style_pointer_cast(&p_data_[i]) = x; #endif } @@ -357,7 +359,7 @@ struct DynamicBuffer typename enable_if>::type, typename scalar_type>::type>::value, bool>::type = false> - __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X& x) { using scalar_t = typename scalar_type>::type; @@ -378,12 +380,14 @@ struct DynamicBuffer (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); #elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) - bool constexpr use_amd_buffer_addressing = is_same_v, int32_t>; + bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && is_same_v, int32_t>; #elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT bool constexpr use_amd_buffer_addressing = - is_same_v, float> || - (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || - (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); + sizeof(IndexType) <= sizeof(int32_t) && + (is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0)); #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -408,12 +412,12 @@ struct DynamicBuffer typename enable_if>::type, typename scalar_type>::type>::value, bool>::type = false> - __host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x) + __host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + constexpr IndexType scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + constexpr IndexType scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X should contain multiple T"); @@ -421,8 +425,9 @@ struct DynamicBuffer static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); #if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 - using scalar_t = typename scalar_type>::type; - bool constexpr use_amd_buffer_addressing = is_same_v, double>; + using scalar_t = typename scalar_type>::type; + bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && is_same_v, double>; #else bool constexpr use_amd_buffer_addressing = false; #endif @@ -455,6 +460,17 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el p, element_space_size}; } +template +__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p, + ElementSpaceSize element_space_size) +{ + return DynamicBuffer{ + p, element_space_size}; +} + template < AddressSpaceEnum BufferAddressSpace, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, diff --git a/include/ck/utility/e8m0.hpp b/include/ck/utility/e8m0.hpp index a692f533f8..f7d2a2f594 100644 --- a/include/ck/utility/e8m0.hpp +++ b/include/ck/utility/e8m0.hpp @@ -67,10 +67,10 @@ struct e8m0_bexp_t namespace utils { template -__host__ __device__ inline int get_exponent_value(T x); +__host__ __device__ inline constexpr int32_t get_exponent_value(T x); template <> -__host__ __device__ inline int get_exponent_value(e8m0_bexp_t x) +__host__ __device__ inline constexpr int32_t get_exponent_value(e8m0_bexp_t x) { return x.data; } diff --git a/include/ck/utility/env.hpp b/include/ck/utility/env.hpp index 809f302f74..469fb70f10 100644 --- a/include/ck/utility/env.hpp +++ b/include/ck/utility/env.hpp @@ -184,4 +184,9 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) } } // namespace ck + +// environment variable to enable logging: +// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED +CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #endif diff --git a/include/ck/utility/mxf8_utils.hpp b/include/ck/utility/mxf8_utils.hpp index b7b98c6455..9046a24a3a 100644 --- a/include/ck/utility/mxf8_utils.hpp +++ b/include/ck/utility/mxf8_utils.hpp @@ -39,7 +39,7 @@ static __device__ float cast_to_f32_from_f8_scaled(float scale, fp8_storage_t v) } template -static __device__ float2_t cast_to_f32x2_from_f8x2_scaled(float scale, fp8x2_storage_t v) +static __device__ float2_t cast_to_f32_from_f8_scaled(float scale, fp8x2_storage_t v) { const auto i16val = bit_cast(v); diff --git a/include/ck/utility/mxfp_utils.hpp b/include/ck/utility/mxfp_utils.hpp index f0a86f8750..cf7a3e8713 100644 --- a/include/ck/utility/mxfp_utils.hpp +++ b/include/ck/utility/mxfp_utils.hpp @@ -32,13 +32,13 @@ template __host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data); template -__host__ __device__ inline int get_exponent_value(T x) +__host__ __device__ inline constexpr int32_t get_exponent_value(T x) { x >>= NumericUtils::mant; x &= ((1 << NumericUtils::exp) - 1); - return static_cast(x); + return static_cast(x); } template diff --git a/include/ck/utility/scaled_type_convert.hpp b/include/ck/utility/scaled_type_convert.hpp index 9a9c53caec..f3e2bd3dd9 100644 --- a/include/ck/utility/scaled_type_convert.hpp +++ b/include/ck/utility/scaled_type_convert.hpp @@ -67,7 +67,7 @@ inline __host__ float2_t scaled_type_convert(e8m0_bexp_t s #endif { #if CK_MX_FP8_CVT_FAST_PATH - return fp8_impl::cast_to_f32x2_from_f8x2_scaled( + return fp8_impl::cast_to_f32_from_f8_scaled( type_convert(scale), x.AsType()[Number<0>{}]); #else return float2_t{scaled_type_convert(scale, x.AsType()[Number<0>{}]), @@ -86,7 +86,7 @@ inline __host__ float2_t scaled_type_convert(e8m0_bexp_t #endif { #if CK_MX_FP8_CVT_FAST_PATH - return fp8_impl::cast_to_f32x2_from_f8x2_scaled( + return fp8_impl::cast_to_f32_from_f8_scaled( type_convert(scale), x.AsType()[Number<0>{}]); #else return float2_t{scaled_type_convert(scale, x.AsType()[Number<0>{}]), diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 99935a6d8d..497625f7e2 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -184,6 +184,21 @@ struct Sequence } }; +namespace impl { +template +struct __integer_sequence; + +template +struct __integer_sequence +{ + using seq_type = Sequence; +}; +} // namespace impl + +template +using make_index_sequence = + typename __make_integer_seq::seq_type; + // merge sequence template struct sequence_merge diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index b4f1545aa9..ec055fb2a2 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -11,8 +11,20 @@ namespace ck { +template +__host__ __device__ constexpr auto generate_tuple_for(F&& f, Sequence) +{ + return make_tuple(f(Number{})...); +} + template __host__ __device__ constexpr auto generate_tuple(F&& f, Number) +{ + return generate_tuple_for(f, make_index_sequence{}); +} + +template +__host__ __device__ constexpr auto generate_tuple(F&& f, LongNumber) { return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); }, typename arithmetic_sequence_gen<0, N, 1>::type{}); diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index b9aeb44999..04ae046ac8 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -117,7 +117,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert(float #if CK_USE_RNE_BF16_CONVERSION return bf16_convert_rtn(x); #else - return uint16_t(u.int32 >> 16); + return uint16_t(static_cast(x) >> 16); #endif } @@ -356,6 +356,180 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(half_t x #endif } +/** + * @brief Converts a float to a 8-bit float type (f8_ocp_t) using stochastic rounding. + * + * @param x The input float value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8( + x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_sr(float2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8( + x)}; +} + +/** + * @brief Converts a float to a 8-bit float type (bf8_ocp_t) using stochastic rounding. + * + * @param x The input float value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(float x) +{ + return bf8_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_sr(float2_t x) +{ + return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit float type (f8_ocp_t) using stochastic rounding. + * + * @param x The input half_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(half_t x) +{ + return f8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_sr(half2_t x) +{ + return f8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit half_t type (bf8_ocp_t) using stochastic rounding. + * + * @param x The input half_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(half_t x) +{ + return bf8_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_sr(half2_t x) +{ + return bf8x2_ocp_t{fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit float type (f8_ocp_t) using stochastic rounding. + * + * @param x The input bhalf_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_sr(bhalf_t x) +{ + return f8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_sr(bhalf2_t x) +{ + return f8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit half_t type (bf8_ocp_t) using stochastic rounding. + * + * @param x The input bhalf_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_sr(bhalf_t x) +{ + return bf8_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * stochastic rounding. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_sr(bhalf2_t x) +{ + return bf8x2_ocp_t{fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + // Declare a template function for fp8 conversion using RNE template __host__ __device__ constexpr Y f8_convert_rne(X x); @@ -466,6 +640,172 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne(half_t #endif } +/** + * @brief Converts a float to a 8-bit float type (f8_ocp_t) using rounding to nearest/even. + * + * @param x The input float value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(float x) +{ + return f8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (f8_ocp_t) using rounding + * to nearest/even. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_rne(float2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a float to a 8-bit float type (bf8_ocp_t) using rounding to nearest/even. + * + * @param x The input float value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(float x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 floats to a vector of 2 8-bit float types (bf8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 floats. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_rne(float2_t x) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_float_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit float type (f8_ocp_t) using rounding to nearest/even. + * + * @param x The input half_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(half_t x) +{ + return f8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (f8_ocp_t) using rounding + * to nearest/even. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_rne(half2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_half_t_to_fp8(x)}; +} + +/** + * @brief Converts a half_t to a 8-bit half_t type (bf8_ocp_t) using rounding to nearest/even. + * + * @param x The input half_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(half_t x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_half_t_to_fp8( + x)}; +} + +/** + * @brief Converts a vector of 2 half_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 half_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_rne(half2_t x) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_half_t_to_fp8( + x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit float type (f8_ocp_t) using rounding to nearest/even. + * + * @param x The input bhalf_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t f8_convert_rne(bhalf_t x) +{ + return f8_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (f8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 f8_ocp_t. + */ +template <> +inline __host__ __device__ f8x2_ocp_t f8_convert_rne(bhalf2_t x) +{ + return f8x2_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8(x)}; +} + +/** + * @brief Converts a bhalf_t to a 8-bit half_t type (bf8_ocp_t) using rounding to nearest/even. + * + * @param x The input bhalf_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t f8_convert_rne(bhalf_t x) +{ + return bf8_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8( + x)}; +} + +/** + * @brief Converts a vector of 2 bhalf_t to a vector of 2 8-bit float types (bf8_ocp_t) using + * rounding to nearest/even. + * + * @param x The input vector of 2 bhalf_t. + * @return The converted vector of 2 bf8_ocp_t. + */ +template <> +inline __host__ __device__ bf8x2_ocp_t f8_convert_rne(bhalf2_t x) +{ + return bf8x2_ocp_t{ + fp8_impl::cvt_bhalf_t_to_fp8( + x)}; +} + // convert fp32 to fp8 template <> inline __host__ __device__ f8_fnuz_t type_convert(float x) @@ -477,17 +817,6 @@ inline __host__ __device__ f8_fnuz_t type_convert(float x) #endif } -// convert fp32 to fp8 -template <> -inline __host__ __device__ f8_ocp_t type_convert(float x) -{ -#if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); -#else - return f8_convert_rne(x); -#endif -} - // convert fp8 to fp32 template <> inline __host__ __device__ float type_convert(f8_fnuz_t x) @@ -524,12 +853,39 @@ inline __host__ __device__ float2_t type_convert(f8x2_fnu #endif } +/** + * @brief Converts a f8_ocp_t value to a float value. + * + * @param x The input f8_ocp_t value. + * @return The converted float value. + */ +template <> +inline __host__ __device__ float type_convert(f8_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + union + { + unsigned int i32val; + fp8_storage_t i8val[4]; + } val; + val.i8val[0] = x.data; + return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0); +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 float values. + * + * @param x The input vector of 2 f8_ocp_t values. + * @return The converted vector of 2 float values. + */ template <> inline __host__ __device__ float2_t type_convert(f8x2_ocp_t x) { #if CK_OCP_FP8_CVT_FAST_PATH - return fp8_impl::cast_to_f32x2_from_f8x2( - x.AsType()[Number<0>{}]); + return __builtin_amdgcn_cvt_pk_f32_fp8(bit_cast(x), false); #else return float2_t{fp8_impl::cast_from_f8( x.AsType()[Number<0>{}]), @@ -538,6 +894,229 @@ inline __host__ __device__ float2_t type_convert(f8x2_ocp_ #endif } +/** + * @brief Converts a f8_ocp_t value to a half_t value. + * + * @param x The input f8_ocp_t value. + * @return The converted half_t value. + */ +template <> +inline __host__ __device__ half_t type_convert(f8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } input; + input.i8val[0] = x.data; + + union + { + half2_t half_vec; + half_t half_arr[2]; + } output; + output.half_vec = __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(input.i16val, /*scale*/ 1.f, 0); + + return output.half_arr[0]; +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 half_t values. + * + * @param x The input vector of 2 f8_ocp_t values. + * @return The converted vector of 2 half_t values. + */ +template <> +inline __host__ __device__ half2_t type_convert(f8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_f16_fp8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return half2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + +/** + * @brief Converts a f8_ocp_t value to a bhalf_t value. + * + * @param x The input f8_ocp_t value. + * @return The converted bhalf_t value. + */ +template <> +inline __host__ __device__ bhalf_t type_convert(f8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } input; + input.i8val[0] = x.data; + + union + { + bhalf2_t bhalf_vec; + bhalf_t bhalf_arr[2]; + } output; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(input.i16val, /*scale*/ 1.f, 0); + + return output.bhalf_arr[0]; +#else + return type_convert( + fp8_impl::cast_from_f8(x.data)); +#endif +} + +/** + * @brief Converts a vector of 2 f8_ocp_t values to a vector of 2 bhalf_t values. + * + * @param x The input vector of 2 f8_ocp_t values. + * @return The converted vector of 2 bhalf_t values. + */ +template <> +inline __host__ __device__ bhalf2_t type_convert(f8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return bhalf2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + +/** + * @brief Converts a bf8_ocp_t value to a float value. + * + * @param x The input bf8_ocp_t value. + * @return The converted float value. + */ +template <> +inline __host__ __device__ float type_convert(bf8_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + union + { + unsigned int i32val; + fp8_storage_t i8val[4]; + } val; + val.i8val[0] = x.data; + return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0); +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 float values. + * + * @param x The input vector of 2 bf8_ocp_t values. + * @return The converted vector of 2 float values. + */ +template <> +inline __host__ __device__ float2_t type_convert(bf8x2_ocp_t x) +{ +#if CK_OCP_FP8_CVT_FAST_PATH + return __builtin_amdgcn_cvt_pk_f32_bf8(bit_cast(x), false); +#else + return float2_t{fp8_impl::cast_from_f8( + x.AsType()[Number<0>{}]), + fp8_impl::cast_from_f8( + x.AsType()[Number<1>{}])}; +#endif +} + +/** + * @brief Converts a bf8_ocp_t value to a half_t value. + * + * @param x The input bf8_ocp_t value. + * @return The converted half_t value. + */ +template <> +inline __host__ __device__ half_t type_convert(bf8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } val; + val.i8val[0] = x.data; + return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(val.i16val, /*scale*/ 1.f, 0)[0]; +#else + return fp8_impl::cast_from_f8(x.data); +#endif +} + +/** + * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 half_t values. + * + * @param x The input vector of 2 bf8_ocp_t values. + * @return The converted vector of 2 half_t values. + */ +template <> +inline __host__ __device__ half2_t type_convert(bf8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_f16_bf8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return half2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + +/** + * @brief Converts a bf8_ocp_t value to a bhalf_t value. + * + * @param x The input bf8_ocp_t value. + * @return The converted bhalf_t value. + */ +template <> +inline __host__ __device__ bhalf_t type_convert(bf8_ocp_t x) +{ +#if defined(__gfx950__) + union + { + uint16_t i16val; + fp8_storage_t i8val[2]; + } input; + input.i8val[0] = x.data; + + union + { + bhalf2_t bhalf_vec; + bhalf_t bhalf_arr[2]; + } output; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0); + + return output.bhalf_arr[0]; +#else + return type_convert( + fp8_impl::cast_from_f8(x.data)); +#endif +} + +/** + * @brief Converts a vector of 2 bf8_ocp_t values to a vector of 2 bhalf_t values. + * + * @param x The input vector of 2 bf8_ocp_t values. + * @return The converted vector of 2 bhalf_t values. + */ +template <> +inline __host__ __device__ bhalf2_t type_convert(bf8x2_ocp_t x) +{ +#if defined(__gfx950__) + return __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(bit_cast(x), /*scale*/ 1.f, 0); +#else + return bhalf2_t{type_convert(float(x.AsType()[Number<0>{}])), + type_convert(float(x.AsType()[Number<1>{}]))}; +#endif +} + template <> inline __host__ __device__ float2_t type_convert(pk_i4_t x) { @@ -610,7 +1189,12 @@ inline __host__ __device__ f8_fnuz_t type_convert(half_t x) #endif } -// convert fp16 to fp8 +/** + * @brief Converts a half_t value to a f8_ocp_t value with rounding determined by a flag. + * + * @param x The input half_t value. + * @return The converted f8_ocp_t value. + */ template <> inline __host__ __device__ f8_ocp_t type_convert(half_t x) { @@ -621,6 +1205,22 @@ inline __host__ __device__ f8_ocp_t type_convert(half_t x) #endif } +/** + * @brief Converts a half_t value to a bf8_ocp_t value with rounding determined by a flag. + * + * @param x The input half_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t type_convert(half_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + // convert fp8 to fp16 template <> inline __host__ __device__ half_t type_convert(f8_fnuz_t x) @@ -645,7 +1245,28 @@ inline __host__ __device__ bf8_fnuz_t type_convert(float x) #endif } -// convert fp32 to bf8 +/** + * @brief Converts a float value to a f8_ocp_t value with rounding determined by a flag. + * + * @param x The input float value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t type_convert(float x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + +/** + * @brief Converts a float value to a bf8_ocp_t value with rounding determined by a flag. + * + * @param x The input float value. + * @return The converted bf8_ocp_t value. + */ template <> inline __host__ __device__ bf8_ocp_t type_convert(float x) { @@ -656,6 +1277,38 @@ inline __host__ __device__ bf8_ocp_t type_convert(float x) #endif } +/** + * @brief Converts a bhalf_t value to a f8_ocp_t value with rounding determined by a flag. + * + * @param x The input bhalf_t value. + * @return The converted f8_ocp_t value. + */ +template <> +inline __host__ __device__ f8_ocp_t type_convert(bhalf_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + +/** + * @brief Converts a bhalf_t value to a bf8_ocp_t value with rounding determined by a flag. + * + * @param x The input bhalf_t value. + * @return The converted bf8_ocp_t value. + */ +template <> +inline __host__ __device__ bf8_ocp_t type_convert(bhalf_t x) +{ +#if CK_USE_SR_F8_CONVERSION + return f8_convert_sr(x); +#else + return f8_convert_rne(x); +#endif +} + // convert bf8 to fp32 template <> inline __host__ __device__ float type_convert(bf8_fnuz_t x) @@ -683,17 +1336,6 @@ inline __host__ __device__ bf8_fnuz_t type_convert(half_t x) #endif } -// convert fp16 to bf8 -template <> -inline __host__ __device__ bf8_ocp_t type_convert(half_t x) -{ -#if CK_USE_SR_F8_CONVERSION - return f8_convert_sr(x); -#else - return f8_convert_rne(x); -#endif -} - // convert bf8 to fp16 template <> inline __host__ __device__ half_t type_convert(bf8_fnuz_t x) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 821b3a8e84..2ea8bf15a7 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -9,10 +9,10 @@ #include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/algorithm/static_encoding_pattern.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" -#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" #include "ck_tile/core/arch/arch.hpp" #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/arch/workgroup_barrier.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/container/container_helper.hpp" @@ -53,6 +53,7 @@ #include "ck_tile/core/tensor/tile_distribution.hpp" #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" +#include "ck_tile/core/tensor/tile_scatter_gather.hpp" #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/tensor/tile_window_linear.hpp" #include "ck_tile/core/tensor/tile_window_utils.hpp" diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index 78884f3f9f..b56bda3741 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -73,10 +73,11 @@ struct TileDistributionEncodingPattern2D LargestVec ? LargestVec : VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim // # of rows in Y dim accessed by single wavefront in one iteration static constexpr index_t Y1 = warp_size / X0; @@ -124,10 +125,11 @@ struct TileDistributionEncodingPattern2D LargestVec ? LargestVec : VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!"); @@ -173,10 +175,11 @@ struct TileDistributionEncodingPattern2D LargestVec ? LargestVec : VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!"); static constexpr index_t Y1 = num_warps; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 33faa3a18b..5d6d6ce348 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -14,6 +14,15 @@ #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" +// This attribute gives a hint to the compiler that a branch is likely to be taken. +// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would +// have been generated. +#if __cplusplus >= 202002L +#define LIKELY(x) (x) [[likely]] +#else +#define LIKELY(x) (__builtin_expect(!!(x), 1)) +#endif + namespace ck_tile { // 128 bit SGPRs to supply buffer resource in buffer instructions @@ -58,10 +67,36 @@ template<> struct buffer_load_trait<4 , thread_buffer> { using payloa // TODO: glc/slc/... template struct buffer_load; + +template +struct buffer_load_if; + +template +struct buffer_store; + +template +struct buffer_store_if; + #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" // TODO: strict aliasing rule seems fail when reinterpret_cast between vector type // (exp_vector_type(xxx)) + +#define HAS_RAW_BUFFER_BUILTINS \ + __has_builtin(__builtin_amdgcn_raw_buffer_load_b32) && \ + __has_builtin(__builtin_amdgcn_make_buffer_rsrc) && \ + __has_builtin(__builtin_amdgcn_raw_buffer_store_b32) + +#if HAS_RAW_BUFFER_BUILTINS +CK_TILE_DEVICE __amdgpu_buffer_rsrc_t cast_to_amdgpu_buffer_rsrc_t(int32x4_t res) +{ + __amdgpu_buffer_rsrc_t as_rsrc; + static_assert(sizeof(res) == sizeof(as_rsrc) && "Size of buffer resource should match"); + memcpy(&as_rsrc, &res, sizeof(res)); + return as_rsrc; +} +#endif + template struct buffer_load<16, pre_nop> { @@ -76,6 +111,11 @@ struct buffer_load<16, pre_nop> { static_assert(sizeof(T) == 16); using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b128( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" @@ -87,6 +127,7 @@ struct buffer_load<16, pre_nop> : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -104,6 +145,11 @@ struct buffer_load<8, pre_nop> { static_assert(sizeof(T) == 8); using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b64( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" @@ -115,6 +161,7 @@ struct buffer_load<8, pre_nop> : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -132,6 +179,12 @@ struct buffer_load<4, pre_nop> { static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b32( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_dword %0, %1, %2, 0 offen offset:%3" @@ -143,6 +196,7 @@ struct buffer_load<4, pre_nop> : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -160,6 +214,12 @@ struct buffer_load<2, pre_nop> { static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b16( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_ushort %0, %1, %2, 0 offen offset:%3" @@ -171,6 +231,7 @@ struct buffer_load<2, pre_nop> : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -188,6 +249,11 @@ struct buffer_load<1, pre_nop> { static_assert(sizeof(T) == 4); using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + reinterpret_cast(value) = __builtin_amdgcn_raw_buffer_load_b16( + cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else if constexpr(pre_nop) asm volatile("s_nop 4\n" "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" @@ -199,12 +265,31 @@ struct buffer_load<1, pre_nop> : "+v"(reinterpret_cast(value)) : "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; -template -struct buffer_load_if; - +#if HAS_RAW_BUFFER_BUILTINS +template +struct buffer_load_if +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0, + bool_constant = {}) + { + if LIKELY(1 <= flag) + { + buffer_load{}( + value, res, v_offset, s_offset, i_offset, flag, bool_constant{}); + } + } +}; +#else template struct buffer_load_if<16, pre_nop> { @@ -214,12 +299,12 @@ struct buffer_load_if<16, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); if constexpr(pre_nop) asm volatile("s_nop 4\n" @@ -248,12 +333,12 @@ struct buffer_load_if<8, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -281,12 +366,12 @@ struct buffer_load_if<4, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -314,12 +399,12 @@ struct buffer_load_if<2, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -347,12 +432,12 @@ struct buffer_load_if<1, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -370,9 +455,9 @@ struct buffer_load_if<1, pre_nop> : "memory"); } }; +#endif + #pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" -template -struct buffer_store; template <> struct buffer_store<16> @@ -387,10 +472,16 @@ struct buffer_store<16> { static_assert(sizeof(T) == 16); using mbuf_t = fp32x4_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b128( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -407,10 +498,16 @@ struct buffer_store<8> { static_assert(sizeof(T) == 8); using mbuf_t = fp32x2_t; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b64( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -427,10 +524,16 @@ struct buffer_store<4> { static_assert(sizeof(T) == 4); using mbuf_t = float; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b32( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -447,10 +550,16 @@ struct buffer_store<2> { static_assert(sizeof(T) == 2); using mbuf_t = short; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b16( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; @@ -467,16 +576,38 @@ struct buffer_store<1> { static_assert(sizeof(T) == 4); using mbuf_t = float; +#if HAS_RAW_BUFFER_BUILTINS + index_t s_offset = i_offset; + __builtin_amdgcn_raw_buffer_store_b8( + bit_cast(value), cast_to_amdgpu_buffer_rsrc_t(res), v_offset, s_offset, 0); +#else asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3" : : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) : "memory"); +#endif } }; +#if HAS_RAW_BUFFER_BUILTINS template -struct buffer_store_if; - +struct buffer_store_if +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + if LIKELY(1 <= flag) + { + buffer_store{}(value, res, v_offset, s_offset, i_offset); + } + } +}; +#else template <> struct buffer_store_if<16> { @@ -490,7 +621,7 @@ struct buffer_store_if<16> { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = fp32x4_t; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -547,7 +678,7 @@ struct buffer_store_if<4> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -575,7 +706,7 @@ struct buffer_store_if<2> { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = short; + using mbuf_t = short; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -603,7 +734,7 @@ struct buffer_store_if<1> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -617,6 +748,7 @@ struct buffer_store_if<1> : "memory"); } }; +#endif CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) { diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp deleted file mode 100644 index 0b9956cd01..0000000000 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ /dev/null @@ -1,2559 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#if CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN - -#include "ck_tile/core/numeric/integer.hpp" -#include "ck_tile/core/numeric/integral_constant.hpp" -#include "ck_tile/core/numeric/vector_type.hpp" -#include "ck_tile/core/container/container_helper.hpp" -#include "ck_tile/core/container/thread_buffer.hpp" -#include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/core/utility/bit_cast.hpp" -#include "ck_tile/core/utility/functional.hpp" - -namespace ck_tile { - -// 128 bit SGPRs to supply buffer resource in buffer instructions -// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions -struct __attribute__((packed)) buffer_resource -{ - const void* ptr; - uint32_t range; - uint32_t config; -}; - -CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) -{ - buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; - int32x4_t r = __builtin_bit_cast(int32x4_t, res); - r.x = __builtin_amdgcn_readfirstlane(r.x); - r.y = __builtin_amdgcn_readfirstlane(r.y); - r.z = __builtin_amdgcn_readfirstlane(r.z); - r.w = __builtin_amdgcn_readfirstlane(r.w); - return r; -} - -namespace impl { -// below type indicate the data type used for buffer load inline asm -// clang-format off -template struct buffer_load_trait; - -template struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; }; -template struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; }; -template struct buffer_load_trait<4 , T> { using payload_t = float; }; -template struct buffer_load_trait<2 , T> { using payload_t = float; }; -template struct buffer_load_trait<1 , T> { using payload_t = float; }; - -#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA -template<> struct buffer_load_trait<16, thread_buffer> { using payload_t = bf16x8_t; }; -template<> struct buffer_load_trait<8 , thread_buffer> { using payload_t = bf16x4_t; }; -template<> struct buffer_load_trait<4 , thread_buffer> { using payload_t = bf16x2_t; }; -#endif -// clang-format on -} // namespace impl - -// TODO: glc/slc/... -template -struct buffer_load; -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" -// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type -// (exp_vector_type(xxx)) -template -struct buffer_load<16, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 16); - using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load<8, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 8); - using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load<4, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_dword %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load<2, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_ushort %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_ushort %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load<1, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - else - asm volatile("buffer_load_ubyte %0, %1, %2, 0 offen offset:%3" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_load_if; - -template -struct buffer_load_if<16, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 16); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; - static_assert(sizeof(mbuf_t) == sizeof(T)); - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; - -template -struct buffer_load_if<8, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 8); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dwordx2 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; - -template -struct buffer_load_if<4, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_dword %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; - -template -struct buffer_load_if<2, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_ushort %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; - -template -struct buffer_load_if<1, pre_nop> -{ - template - CK_TILE_DEVICE void operator()(T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 0, - bool_constant = {}) - { - static_assert(sizeof(T) == 4); - auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - else - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_load_ubyte %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : "+v"(reinterpret_cast(value)) - : "v"(v_offset), "s"(res), "n"(i_offset), "v"(flag), "s"(saved_exec) - : "memory"); - } -}; -#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" -template -struct buffer_store; - -template <> -struct buffer_store<16> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 16); - using mbuf_t = fp32x4_t; - asm volatile("buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct buffer_store<8> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 8); - using mbuf_t = fp32x2_t; - asm volatile("buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct buffer_store<4> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm volatile("buffer_store_dword %0, %1, %2, 0 offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct buffer_store<2> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 2); - using mbuf_t = short; - asm volatile("buffer_store_short %0, %1, %2, 0 offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct buffer_store<1> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag*/ = 1) - { - static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm volatile("buffer_store_byte %0, %1, %2, 0 offen offset:%3" - : - : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "n"(i_offset) - : "memory"); - } -}; - -template -struct buffer_store_if; - -template <> -struct buffer_store_if<16> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 16); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -template <> -struct buffer_store_if<8> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 8); - auto save_exec = __builtin_amdgcn_read_exec(); - // TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch - using mbuf_t = ext_vector_t; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_dwordx2 %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -template <> -struct buffer_store_if<4> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 4); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -template <> -struct buffer_store_if<2> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 2); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = short; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -template <> -struct buffer_store_if<1> -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 4); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(bit_cast(value)), - "v"(v_offset), - "s"(res), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); -} - -CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0) -{ - asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory"); -} - -template -struct buffer_atomic_add_if; - -template -struct buffer_atomic_add_if -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t flag = 1) - { - static_assert(sizeof(T) == 4); - auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; - asm volatile("v_cmpx_le_u32 exec, 1, %4\n" - "global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n" - "s_mov_b64 exec %5" - : - : "v"(v_offset), - "v"(bit_cast(value)), - "s"(res.xy), - "n"(i_offset), - "v"(flag), - "s"(save_exec) - : "memory"); - } -}; - -template -struct buffer_atomic_add; - -template -struct buffer_atomic_add -{ - template - CK_TILE_DEVICE void operator()(const T& value, - int32x4_t res /*buffer resource*/, - index_t v_offset, - index_t /*s_offset*/, - index_t i_offset /*max 0xFFF*/, - index_t /*flag = 1*/) - { - static_assert(sizeof(T) == 4); - using mbuf_t = float; - asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3" - : - : "v"(v_offset), "v"(bit_cast(value)), "s"(res.xy), "n"(i_offset) - : "memory"); - } -}; - -namespace impl { -// below type indicate the data type used for buffer load inline asm -// clang-format off -template struct smem_load_trait; - -template struct smem_load_trait<16, T> { using payload_t = fp32x4_t; }; -template struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; }; -template struct smem_load_trait<4 , T> { using payload_t = float; }; -template struct smem_load_trait<2 , T> { using payload_t = float; }; -template struct smem_load_trait<1 , T> { using payload_t = float; }; - -// clang-format on -} // namespace impl - -// NOTE: smem load/store no need pre_nop to make sure dependency by sw, happy :) -template -struct smem_load; - -template <> -struct smem_load<16> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 16); - using mbuf_t = typename impl::smem_load_trait<16, T>::payload_t; - asm volatile("ds_read_b128 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct smem_load<8> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 8); - using mbuf_t = typename impl::smem_load_trait<8, T>::payload_t; - asm volatile("ds_read_b64 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct smem_load<4> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 4); - using mbuf_t = typename impl::smem_load_trait<4, T>::payload_t; - asm volatile("ds_read_b32 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct smem_load<2> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually - using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t; - asm volatile("ds_read_u16 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -template <> -struct smem_load<1> -{ - template - CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset) - { - static_assert(sizeof(T) == 4); - using mbuf_t = typename impl::smem_load_trait<1, T>::payload_t; - asm volatile("ds_read_u8 %0, %1 offset:%2" - : "=v"(reinterpret_cast(value)) // ! direct write - : "v"(v_offset), "n"(i_offset) - : "memory"); - } -}; - -// clang-format off -namespace impl{ - -// can't use "+v" since there could be potential extra move(read/write) -// use "v" can help remove such duplicated moves -// besides, fake this as "memory" operation to force later valu after this fence -// TODO: may have scratch (because this is memory?) -// need to reduce extra move inside compiler -template -CK_TILE_DEVICE void insert_dummy_dep_per_dword(array& b) -{ - constexpr auto kSize = remove_cvref_t::size(); - static_for<0, kSize, 1>{}([&](auto i){ - asm volatile(" " : : "v"(b.get(number{})) : "memory"); - }); -} -#if 1 -// below specialization just merge size() of dwords into single section -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), - "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), - "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})), - "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})), - "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})) : "memory"); -} - -template<> -CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(array& b) -{ - asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), - "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})), - "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})), - "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})), - "v"(b.get(number<16>{})), "v"(b.get(number<17>{})), "v"(b.get(number<18>{})), "v"(b.get(number<19>{})), - "v"(b.get(number<20>{})), "v"(b.get(number<21>{})), "v"(b.get(number<22>{})), "v"(b.get(number<23>{})), - "v"(b.get(number<24>{})), "v"(b.get(number<25>{})), "v"(b.get(number<26>{})), "v"(b.get(number<27>{})), - "v"(b.get(number<28>{})), "v"(b.get(number<29>{})), "v"(b.get(number<30>{})), "v"(b.get(number<31>{})) : "memory"); -} -#endif -CK_TILE_DEVICE void insert_dummy_dep() {} - -template -CK_TILE_DEVICE void insert_dummy_dep(T & buffer) -{ - // TODO: indeed we expect T to be multiple of dword. subdword is always buggy - using da_type = array; - auto & dummy = reinterpret_cast(buffer); - insert_dummy_dep_per_dword(dummy); -} - -template -CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by) -{ - insert_dummy_dep(bx); - insert_dummy_dep(by...); -} -} -// clang-format on -template -CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0, T&... o) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); - impl::insert_dummy_dep(o...); -} - -CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); -} - -CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); -} - -// buffer load i8 -CK_TILE_DEVICE_EXTERN int8_t -llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8.v4i32"); - -CK_TILE_DEVICE_EXTERN int8x2_t -llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8.v4i32"); - -CK_TILE_DEVICE_EXTERN int8x4_t -llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8.v4i32"); - -// buffer load i16 -CK_TILE_DEVICE_EXTERN int16_t -llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16.v4i32"); - -CK_TILE_DEVICE_EXTERN int16x2_t -llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16.v4i32"); - -CK_TILE_DEVICE_EXTERN int16x4_t -llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16.v4i32"); - -// buffer load i32 -CK_TILE_DEVICE_EXTERN int32_t -llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32.v4i32"); - -CK_TILE_DEVICE_EXTERN int32x2_t -llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32.v4i32"); - -CK_TILE_DEVICE_EXTERN int32x4_t -llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32.v4i32"); - -// buffer load fp16 -CK_TILE_DEVICE_EXTERN _Float16 -llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16.v4i32"); - -CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_load_fp16x2( - int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16.v4i32"); - -CK_TILE_DEVICE_EXTERN fp16x4_t llvm_amdgcn_raw_buffer_load_fp16x4( - int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16.v4i32"); - -// buffer load fp32 -CK_TILE_DEVICE_EXTERN float -llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32.v4i32"); - -CK_TILE_DEVICE_EXTERN fp32x2_t llvm_amdgcn_raw_buffer_load_fp32x2( - int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32.v4i32"); - -CK_TILE_DEVICE_EXTERN fp32x4_t llvm_amdgcn_raw_buffer_load_fp32x4( - int32x4_t srsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32.v4i32"); - -// buffer store i8 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8.v4i32"); - -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8.v4i32"); - -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8.v4i32"); - -// buffer store i16 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x2( - int16x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i16x4( - int16x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16.v4i32"); - -// buffer store i32 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32.v4i32"); - -// buffer store ui16 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_ui16(uint16_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x2( - uint16x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_ui16x4( - uint16x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x2( - int32x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_i32x4( - int32x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32.v4i32"); - -// buffer store fp16 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x2( - fp16x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp16x4( - fp16x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16.v4i32"); - -// buffer store fp32 -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_store_fp32(float vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x2( - fp32x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32.v4i32"); - -CK_TILE_DEVICE_EXTERN void llvm_amdgcn_raw_buffer_store_fp32x4( - fp32x4_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32.v4i32"); - -// buffer atomic-add fp16 -CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( - fp16x2_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16.v4i32"); - -// buffer atomic-add i32 -CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( - int32_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32.v4i32"); - -// buffer atomic-add fp32 -CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32( - float vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32.v4i32"); - -// buffer atomic-max fp64 -CK_TILE_DEVICE_EXTERN double llvm_amdgcn_raw_buffer_atomic_max_fp64( - double vdata, - int32x4_t rsrc, // dst_wave_buffer_resource - int voffset, // dst_thread_addr_offset - int soffset, // dst_wave_addr_offset - int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64.v4i32"); - -// Direct loads from global to LDS. -CK_TILE_DEVICE_EXTERN void -llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, - __attribute__((address_space(3))) uint32_t* lds_ptr, - index_t size, - index_t voffset, - index_t soffset, - index_t offset, - index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds.v4i32"); - -template -CK_TILE_DEVICE void async_buffer_load_dword_v(void* smem, - int32x4_t rsrc, - index_t voffset, - index_t /*soffset*/, - index_t ioffset /*max 0xFFF*/, - index_t /*flag*/ = 0, - bool_constant = {}) -{ - if constexpr(pre_nop) - asm volatile("s_nop 4\n" - "buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); - else - asm volatile("buffer_load_dword %1, %2, 0 offen offset:%3 lds" - : "=r"(smem) /*dummy dependency for smem*/ - : "v"(voffset), "s"(rsrc), "n"(ioffset) - : "memory"); -} - -CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) -{ - asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); -} - -// memory coherency bit for buffer store/load instruction -// check ISA manual for each GFX target -// e.g. for -// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf, -// page 67~68 -enum struct amd_buffer_coherence_enum -{ - coherence_default = 0, // default value - glc = 1, - slc = 2, - glc_slc = 3, -}; - -template -CK_TILE_DEVICE thread_buffer -amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) -{ - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, - "wrong! not implemented"); - - using rtn_type = thread_buffer; - - if constexpr(N == 1) - { - return bit_cast(llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 2) - { - - int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 4) - { - int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 8) - { - int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - else if constexpr(N == 16) - { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - return bit_cast(tmp); - } - else if constexpr(N == 32) - { - int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - int32x4_t tmp1 = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int32_t), - static_cast(coherence)); - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = tmp0; - tmp.template get_as()(number<1>{}) = tmp1; - - return bit_cast(tmp); - } - else if constexpr(N == 64) - { - int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - int32x4_t tmp1 = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(int32_t), - static_cast(coherence)); - int32x4_t tmp2 = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 8 * sizeof(int32_t), - static_cast(coherence)); - int32x4_t tmp3 = - llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 12 * sizeof(int32_t), - static_cast(coherence)); - - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = tmp0; - tmp.template get_as()(number<1>{}) = tmp1; - tmp.template get_as()(number<2>{}) = tmp2; - tmp.template get_as()(number<3>{}) = tmp3; - - return bit_cast(tmp); - } -} - -#ifndef BUFFER_LOAD_USE_INLINEASM -#define BUFFER_LOAD_USE_INLINEASM 0 -#endif - -template -CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset) -{ - static_assert( - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)), - "wrong! not implemented"); - - using rtn_type = thread_buffer; - - if constexpr(std::is_same::value) // fp32 - { - if constexpr(N == 1) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 2) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 4) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 8) - { - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.template get_as()(number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - - return tmp; - } - else if constexpr(N == 16) - { - thread_buffer tmp; - - tmp.template get_as()(number<0>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - tmp.template get_as()(number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - - tmp.template get_as()(number<2>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 8 * sizeof(float), - static_cast(coherence)); - - tmp.template get_as()(number<3>{}) = - llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 12 * sizeof(float), - static_cast(coherence)); - - return tmp; - } - } - else if constexpr(std::is_same::value) // fp16 - { - if constexpr(N == 1) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 2) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 4) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 8) - { - // use fp32 load to mimic fp16 load - fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - } - else if constexpr(std::is_same::value) // bf16 - { - if constexpr(N == 1) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 2) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 4) - { - return bit_cast( - llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence))); - } - else if constexpr(N == 8) - { - int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - static_cast(coherence)); - - return bit_cast(tmp); - } - } - else // other datatype - { - auto raw_data = amd_buffer_load_impl_with_bytes( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); - - return bit_cast(raw_data); - } -} - -template -CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, - int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset, - index_t src_linear_addr_offset, - index_t flag = 0, - bool_constant = {}) -{ - constexpr index_t bytes = sizeof(T) * N; - static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, - "wrong! not supported by buffer_load instruction"); - - using type = thread_buffer; - if constexpr(oob_conditional_check) - { - buffer_load_if{}(dst, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_linear_addr_offset, - flag, - bool_constant{}); - } - else - { - buffer_load{}(dst, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_linear_addr_offset, - flag, - bool_constant{}); - } -} - -template -CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, - int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0, - bool_constant = {}) -{ - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); - - async_buffer_load_dword_v(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - 0, - bool_constant{}); -} - -template -CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, - int32x4_t src_wave_buffer_resource, - index_t src_thread_addr_offset, - index_t src_wave_addr_offset, - index_t src_immediate_addr_offset = 0, - index_t flag = 0, - bool_constant = {}) -{ - static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); - - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? v_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - smem, - sizeof(uint32_t), - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource, - smem, - sizeof(uint32_t), - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } -} - -template -CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) -{ - static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, - "wrong! not implemented"); - - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i8(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - - llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 16) - { - llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 32) - { - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 4, - static_cast(coherence)); - } - else if constexpr(N == 64) - { - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 4, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 8, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i32x4( - src_thread_data.template get_as()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t) * 12, - static_cast(coherence)); - } -} - -template -CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) -{ - static_assert( - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && - (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), - "wrong! not implemented"); - - if constexpr(std::is_same::value) // fp32 - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_fp32x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - llvm_amdgcn_raw_buffer_store_fp32x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(float), - static_cast(coherence)); - } - } - else if constexpr(std::is_same::value) // fp16 - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp16x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp16x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { -#if 0 - thread_buffer tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(fp16_t), - static_cast(coherence)); -#else - llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); -#endif - } - } - else if constexpr(std::is_same::value) // bf16 - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_i16x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_i16x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_i16x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_i16x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(bf16_t), - static_cast(coherence)); - } - } - else if constexpr(std::is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_ui16(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_ui16x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_ui16x4(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - } - else if constexpr(N == 8) - { - llvm_amdgcn_raw_buffer_store_ui16x4( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - static_cast(coherence)); - - llvm_amdgcn_raw_buffer_store_ui16x4( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(uint16_t), - static_cast(coherence)); - } - } - else - { - using r_t = thread_buffer; - - amd_buffer_store_impl_with_bytes(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset); - } -} - -template -CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer& dst_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset, - index_t dst_linear_addr_offset, - index_t is_valid_element = 1) -{ - constexpr index_t bytes = sizeof(T) * N; - static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, - "wrong! not supported by buffer_store instruction"); - - using type = thread_buffer; - if constexpr(oob_conditional_check) - { - buffer_store_if{}(dst_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - dst_linear_addr_offset, - is_valid_element); - } - else - { - buffer_store{}(dst_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - dst_linear_addr_offset); - } -} - -template -CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer& src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) -{ - static_assert((std::is_same::value && (N == 1 || N == 2 || N == 4)) || - (std::is_same::value && (N == 2 || N == 4 || N == 8)) || - (std::is_same::value && (N == 1 || N == 2 || N == 4)), - "wrong! not implemented"); - - if constexpr(std::is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_atomic_add_fp32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(float), - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(float), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 2 * sizeof(float), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_fp32( - src_thread_data.template get_as()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 3 * sizeof(float), - 0); - } - } - else if constexpr(std::is_same::value) - { - if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - static_for<0, 2, 1>{}([&](auto i) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2( - src_thread_data.template get_as()[i], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + i * sizeof(fp16x2_t), - 0); - }); - } - else if constexpr(N == 8) - { - static_for<0, 4, 1>{}([&](auto i) { - llvm_amdgcn_raw_buffer_atomic_add_fp16x2( - src_thread_data.template get_as()[i], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + i * sizeof(fp16x2_t), - 0); - }); - } - } - else if constexpr(std::is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_atomic_add_i32(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t), - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(int32_t), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 2 * sizeof(int32_t), - 0); - - llvm_amdgcn_raw_buffer_atomic_add_i32( - src_thread_data.template get_as()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 3 * sizeof(int32_t), - 0); - } - } -} - -template -CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer src_thread_data, - int32x4_t dst_wave_buffer_resource, - index_t dst_thread_addr_offset, - index_t dst_wave_addr_offset) -{ - static_assert((std::is_same::value && (N == 1 || N == 2 || N == 4)), - "wrong! not implemented"); - if constexpr(std::is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_atomic_max_fp64(bit_cast(src_thread_data), - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(double), - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + sizeof(double), - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<2>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 2 * sizeof(double), - 0); - - llvm_amdgcn_raw_buffer_atomic_max_fp64( - src_thread_data.template get_as()[number<3>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 3 * sizeof(double), - 0); - } - } -} - -// buffer_load requires: -// 1) p_src_wave must point to global memory space -// 2) p_src_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. -// oob_conditional_check : dynamic check if out-of-bound -template -CK_TILE_DEVICE thread_buffer -amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, - index_t src_thread_element_offset, - bool src_thread_element_valid, - index_t src_element_space_size) -{ - const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); - - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - -#if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = [&]() { - if constexpr(oob_conditional_check) - return src_thread_element_valid ? 0 : 0x80000000; - else - return 0; - }(); - return amd_buffer_load_impl( - src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); -#else - thread_buffer tmp = - amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); - if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{numeric::zero()}; - else - return tmp; -#endif -} - -// buffer_load requires: -// 1) p_src_wave must point to global memory space -// 2) p_src_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. -template -CK_TILE_DEVICE thread_buffer -amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, - index_t src_thread_element_offset, - bool src_thread_element_valid, - index_t src_element_space_size, - T customized_value) -{ - const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); - - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - - thread_buffer tmp = - amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); - - if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{customized_value}; - else - return tmp; -} - -template -CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, - const T* p_src_wave, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - index_t src_element_space_size, - index_t is_valid_element = 0, - bool_constant = {}) -{ - const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); - - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_buffer_load_raw_impl( - dst, - src_wave_buffer_resource, - src_thread_addr_offset, - 0, - src_linear_addr_offset, - is_valid_element, - bool_constant{}); -} - -// This version support buffer resource as input arg -template -CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, - const int32x4_t src_wave_buffer_resource, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - index_t is_valid_element = 0, - bool_constant = {}) -{ - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_buffer_load_raw_impl( - dst, - src_wave_buffer_resource, - src_thread_addr_offset, - 0, - src_linear_addr_offset, - is_valid_element, - bool_constant{}); -} - -// unfortunately async copy can not make sure invalid data is zero inside LDS -// ... unless people manually write zero to LDS at the proper address. -// so not support invalid_element check for now. -// buffer_load OOB still working. -template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, - const T* p_src_wave, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - index_t src_element_space_size, - bool_constant = {}) -{ - const int32x4_t src_wave_buffer_resource = - make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); - - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_async_buffer_load_impl(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - 0, - src_linear_addr_offset, - bool_constant{}); -} - -// This version support buffer resource as input arg -template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem, - const int32x4_t src_wave_buffer_resource, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - bool_constant = {}) -{ - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_async_buffer_load_impl(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - 0, - src_linear_addr_offset, - bool_constant{}); -} - -// This version support buffer resource as input arg -template -CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem, - const int32x4_t src_wave_buffer_resource, - index_t src_thread_element_offset, - index_t src_linear_element_offset, - bool is_valid_element, - bool_constant = {}) -{ - index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); - index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T); - - amd_async_buffer_load(smem, - src_wave_buffer_resource, - src_thread_addr_offset, - 0, - src_linear_addr_offset, - is_valid_element, - bool_constant{}); -} - -// buffer_store requires: -// 1) p_dst_wave must point to global memory -// 2) p_dst_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. -template -CK_TILE_DEVICE void amd_buffer_store(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - -#if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = [&]() { - if constexpr(oob_conditional_check) - return dst_thread_element_valid ? 0 : 0x80000000; - else - return 0; - }(); - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); -#else - if constexpr(oob_conditional_check) - { - if(dst_thread_element_valid) - { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } - } - else - { - amd_buffer_store_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } -#endif -} - -template -CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const index_t dst_linear_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T); - - amd_buffer_store_raw_impl(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - 0, - dst_linear_addr_offset, - dst_thread_element_valid); -} - -// buffer_atomic_add requires: -// 1) p_dst_wave must point to global memory -// 2) p_dst_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. -template -CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - -#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - - amd_buffer_atomic_add_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); -#else - if(dst_thread_element_valid) - { - amd_buffer_atomic_add_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } -#endif -} - -template -CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const index_t dst_linear_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size, - bool_constant = {}) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T); - - if constexpr(oob_conditional_check) - { - buffer_atomic_add_if{}(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - 0, - dst_linear_addr_offset, - dst_thread_element_valid); - } - else - { - buffer_atomic_add{}(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - 0, - dst_linear_addr_offset, - 1); - } -} - -// buffer_atomic_max requires: -// 1) p_dst_wave must point to global memory -// 2) p_dst_wave must be a wavewise pointer. -// It is user's responsibility to make sure that is true. -template -CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer& src_thread_data, - T* p_dst_wave, - const index_t dst_thread_element_offset, - const bool dst_thread_element_valid, - const index_t dst_element_space_size) -{ - const int32x4_t dst_wave_buffer_resource = - make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); - - index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); - -#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - - amd_buffer_atomic_max_impl( - src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); -#else - if(dst_thread_element_valid) - { - amd_buffer_atomic_max_impl( - src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); - } -#endif -} - -template -CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, - const index_t global_offset, - T* lds_base_ptr, - const index_t lds_offset, - const bool is_valid, - const index_t src_element_space_size) -{ - // Direct loads require that each thread reads and writes exactly a single DWORD. - constexpr auto dword_bytes = 4; - constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; - static_assert(bytes_per_thread == dword_bytes); - - const uint32_t* global_ptr = - reinterpret_cast(reinterpret_cast(global_base_ptr)); - const int32x4_t src_resource = - make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T)); - const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; - -#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM - T* lds_ptr = lds_base_ptr + lds_offset; - auto const lds_ptr_sgpr = - __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); - asm volatile("s_mov_b32 m0, %0; \n\t" - "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), - "v"(global_offset_bytes), - "s"(src_resource) - : "memory"); -#else - // LDS pointer must be attributed with the LDS address space. - __attribute__((address_space(3))) uint32_t* lds_ptr = - reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( - reinterpret_cast(lds_base_ptr + lds_offset)); - - llvm_amdgcn_raw_buffer_load_lds( - src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); -#endif -} - -} // namespace ck_tile - -#endif // CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 09de5f325f..1d3cf5c010 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -154,4 +154,13 @@ __host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_addres #pragma clang diagnostic pop } +CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity() +{ +#if defined(__gfx950__) + return 163840; +#else + return 65536; +#endif +} + } // namespace ck_tile diff --git a/include/ck_tile/core/arch/workgroup_barrier.hpp b/include/ck_tile/core/arch/workgroup_barrier.hpp new file mode 100644 index 0000000000..827a490fcb --- /dev/null +++ b/include/ck_tile/core/arch/workgroup_barrier.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { + +struct workgroup_barrier +{ + CK_TILE_DEVICE workgroup_barrier(uint32_t* ptr) : base_ptr(ptr) {} + + CK_TILE_DEVICE uint32_t ld(uint32_t offset = 0) + { + return __atomic_load_n(base_ptr + offset, __ATOMIC_RELAXED); + } + + CK_TILE_DEVICE void wait_eq(uint32_t value, uint32_t offset = 0) + { + if(threadIdx.x == 0) + { + while(ld(offset) != value) {} + } + __syncthreads(); + } + + CK_TILE_DEVICE void wait_lt(uint32_t value, uint32_t offset = 0) + { + if(threadIdx.x == 0) + { + while(ld(offset) < value) {} + } + __syncthreads(); + } + + CK_TILE_DEVICE void wait_set(uint32_t compare, uint32_t value, uint32_t offset = 0) + { + if(threadIdx.x == 0) + { + while(atomicCAS(base_ptr + offset, compare, value) != compare) {} + } + __syncthreads(); + } + + // enter critical zoon, assume buffer is zero when launch kernel + CK_TILE_DEVICE void aquire(uint32_t offset = 0) { wait_set(offset, 0, 1); } + + // exit critical zoon, assume buffer is zero when launch kernel + CK_TILE_DEVICE void release(uint32_t offset = 0) { wait_set(offset, 1, 0); } + + CK_TILE_DEVICE void inc(uint32_t offset = 0) + { + __syncthreads(); + if(threadIdx.x == 0) + { + atomicAdd(base_ptr + offset, 1); + } + } + + uint32_t* base_ptr; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index b1d201e30e..27133fa847 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -15,7 +15,8 @@ #define __gfx103__ #endif #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ - defined(__gfx1103__) || defined(__gfx11_generic__) + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx11_generic__) #define __gfx11__ #endif #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) @@ -28,12 +29,6 @@ #include "hip/hip_fp16.h" #endif -#include "ck_tile/core/utility/env.hpp" - -// environment variable to enable logging: -// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED -CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) - #ifdef __HIPCC__ #define CK_TILE_HOST inline __host__ #define CK_TILE_DEVICE inline __device__ @@ -262,5 +257,5 @@ CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) #endif #ifndef CK_TILE_WA_ISSUE_2028 -#define CK_TILE_WA_ISSUE_2028 1 +#define CK_TILE_WA_ISSUE_2028 0 #endif diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp index fa63597db4..94aa40e278 100644 --- a/include/ck_tile/core/container/array.hpp +++ b/include/ck_tile/core/container/array.hpp @@ -19,6 +19,25 @@ namespace ck_tile { // array buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0}) // use make_array_with({...}) to construct an array with compatible behavior as old ck // TODO: manually added constructor same as old ck +/** + * @brief A fixed-size array container similar to std::array with additional utilities. + * + * This template class provides a lightweight fixed-size array with value semantics, + * supporting both host and device functionality for GPU programming. It includes + * specialized initialization methods and type punning capabilities. + * + * @tparam T_ The type of elements in the array + * @tparam N_ The fixed number of elements in the array + * + * @note This implementation provides additional features beyond std::array: + * - GPU compatibility via CK_TILE_HOST_DEVICE macros + * - Type punning via get_as() and set_as() methods + * - Various specialized access methods + * - Specialized initialization behaviors + * + * The initializer_list constructor fills remaining elements with the last value + * provided if the list size is smaller than N, which is different than std::array. + */ template struct array { @@ -142,6 +161,14 @@ struct array // empty Array +/// @brief Specialization of array container for zero elements. +/// +/// This is a specialization of the array container template for the case where the number of +/// elements is 0. It provides the same interface as the general array template, but with operations +/// appropriate for an empty array. +/// +/// @tparam T The type of elements stored in the array (not used in this specialization but +/// maintained for API consistency). template struct array { diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index fd02177e25..3700d348e7 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -396,11 +396,16 @@ struct tuple_array_impl }; } // namespace impl +template +CK_TILE_HOST_DEVICE constexpr auto generate_tuple_for(F&& f, sequence) +{ + return make_tuple(f(number{})...); +} + template CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number) { - return unpack([&f](auto&&... is) { return make_tuple(f(is)...); }, - typename arithmetic_sequence_gen<0, N, 1>::type{}); + return generate_tuple_for(f, make_index_sequence{}); } template diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index a4e8ca6a2b..b5da468319 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -530,7 +530,7 @@ CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x) } else { - if(x == 0x80) + if(x == SrcT(0x80)) { return fNeg0; } diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 6bdcb509b0..8176fe551c 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -487,6 +487,9 @@ struct log2e template constexpr T log2e_v = log2e::value; +template +constexpr T log2e_rcp_v = 1. / log2e::value; + CK_TILE_DEVICE float exp2(float x) { return exp2f(x); }; @@ -1380,6 +1383,44 @@ CK_TILE_DEVICE double exp(double x) return exp(x); }; +template +CK_TILE_DEVICE T tanh_fast(T x) +{ + return type_convert((exp(2.0 * type_convert(x)) - 1.0) / + (exp(2.0 * type_convert(x)) + 1.0)); +}; + +template <> +CK_TILE_DEVICE float tanh_fast(float x) +{ + // float a = __builtin_amdgcn_sinh(x); + // float b = __builtin_amdgcn_cosh(x); + // float e = a * __builtin_amdgcn_rcpf(b); + // return e; + + float a = 2.0f * log2e_v * x; + a = __builtin_amdgcn_exp2f(a); + a = __builtin_amdgcn_rcpf(a + 1.0f); + a = 2 * a; + a = 1 - a; + return a; + + // float e, r, s, t, d; + // float a = x; + // s = abs(a); + // t = -log2e_v * 2.0f * s; + // e = __builtin_amdgcn_exp2f(t); + // d = e + 1.0f; + // r = __builtin_amdgcn_rcpf(d); + // r = e * (-r) + r; + // if (s < 4.997253418e-3f) r = a; + // union fipnr {float f; unsigned int i;}; + // fipnr r_; r_.f = r; + // fipnr a_; a_.f = a; + // { r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; } + // return r; +}; + template CK_TILE_DEVICE T log(T x) { diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index bdcfbdd920..c2a093f1ab 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -5,11 +5,7 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/arch/arch.hpp" -#if __clang_major__ == 20 -#include "ck_tile/core/arch/amd_buffer_addressing_builtins.hpp" -#else #include "ck_tile/core/arch/amd_buffer_addressing.hpp" -#endif #include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/container/array.hpp" #include "ck_tile/core/numeric/integer.hpp" diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index b280a1725d..4601261197 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -18,32 +18,8 @@ namespace ck_tile { -template -CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution& tile_window, - number = {}, - bool_constant = {}) -{ - return tile_window.load(number{}, bool_constant{}); -} - -template -CK_TILE_DEVICE auto load_tile(const tile_window_linear& tile_window, +template +CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, number = {}, bool_constant = {}) { @@ -51,35 +27,11 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, - const tile_window_with_static_distribution& tile_window, - number = {}, - bool_constant = {}) -{ - return tile_window.load(dst_tile, number{}, bool_constant{}); -} - -template -CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, - const tile_window_linear& tile_window, + const TileWindow_& tile_window, number = {}, bool_constant = {}) { @@ -138,42 +90,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile, } template -CK_TILE_DEVICE auto -async_load_tile_raw(LdsTileWindow_&& lds_tile, - const tile_window_with_static_distribution& tile_window, - number = {}, - bool_constant = {}, - bool_constant = {}) -{ - return tile_window.async_load_raw(lds_tile, - number{}, - bool_constant{}, - bool_constant{}); -} - -template CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile, - const tile_window_linear& tile_window, + const TileWindow_& tile_window, number = {}, bool_constant = {}, bool_constant = {}) diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp index 336793c5b1..656ce8d20d 100644 --- a/include/ck_tile/core/tensor/tensor_view.hpp +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -210,6 +210,27 @@ struct tensor_view bool_constant{}); } + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void + async_get_vectorized_elements_raw(remove_cvref_t* smem, + const TensorCoord& coord, + index_t coord_extra_offset, + index_t linear_offset, + bool_constant = {}) const + { + return buf_.template async_get_raw( + smem, + (coord.get_offset() + coord_extra_offset) / PackedSize, + linear_offset / PackedSize, + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); + } + template CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, const tensor_descriptor& desc) { - auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + auto buffer_view = + make_buffer_view(p, desc.get_element_space_size()); return tensor_view{buffer_view, desc}; } template {}, number{}); - auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + auto buffer_view = + make_buffer_view(p, desc.get_element_space_size()); return tensor_view{buffer_view, desc}; } template @@ -458,7 +468,8 @@ make_naive_tensor_view_packed(DataType* p, auto desc = make_naive_tensor_descriptor_packed(lengths, number{}); - auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + auto buffer_view = + make_buffer_view(p, desc.get_element_space_size()); return tensor_view{buffer_view, desc}; } @@ -488,6 +499,7 @@ template +struct tile_scatter_gather +{ + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using TileDstr = remove_cvref_t; + using PageIdxArray = remove_cvref_t; + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + + using DataType = remove_cvref_t; + + static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); + static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static_assert(NumCoord == 1); + + // TODO: check WindowLengths and StaticTileDistribution are consistent + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + static_assert(TileDstr::is_static(), "wrong!"); + + static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), + "wrong! inconsistent # of diemsnions"); + + using AdaptorTopIndex = array; + using BottomTensorIndex = array; + + using WindowAdaptorCoord = + decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); + + using BottomTensorCoord = + decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); + + struct load_store_traits + { + private: + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + tile_scatter_gather::get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + public: + static constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = + get_vector_dim_y_scalar_per_vector().template at<1>(); + + // using vector_type_t = vector_type_maker_t; + // using vector_t = typename vector_type_t::type; + using vector_t = thread_buffer; + + private: + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto tile_dstr = TileDstr{}; + + constexpr auto thread_tensor_lengths_ys = + to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + public: + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); + static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord"); + }; + + static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord; + + CK_TILE_DEVICE constexpr tile_scatter_gather() = default; + + CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view, + const WindowLengths& window_lengths, + const BottomTensorIndex& window_origin, + const TileDstr& tile_distribution, + const PageIdxArray& page_idx) + : bottom_tensor_view_{bottom_tensor_view}, + window_lengths_{window_lengths}, + window_origin_{window_origin}, + tile_dstr_{tile_distribution}, + page_idx_{page_idx}, + pre_computed_coords_{} + { +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_distribution), + array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); + bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0; + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() + { + return TileDstr::is_static(); + } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + + CK_TILE_DEVICE constexpr void + set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data) + { + bottom_tensor_view_.buf_.p_data_ = data; + } + + // move thread's window adaptor coordinate and bottom tensor coordinate + // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] + template + CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( + WindowAdaptorCoord& window_adaptor_thread_coord, + BottomTensorCoord& bottom_tensor_thread_coord, + const ATopIndex& idx_diff_adaptor_top) const + { + array idx_diff_adaptor_bottom; + + move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + window_adaptor_thread_coord, + idx_diff_adaptor_top, + idx_diff_adaptor_bottom); + + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_coord, + idx_diff_adaptor_bottom); + } + + // return vector dimension among [y0, y1, ...] + CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() + { + // bottom tensor top dimension vector lengths and strides + const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = + BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); + + // window vector lengths/strides + const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; + const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; + + // window adaptor [p0, p1, ..., y0, y1, ...] + array window_adaptor_vector_lengths{ + -1}; + array window_adaptor_vector_strides{ + -1}; + + constexpr auto window_adaptor_bottom_dims = + WindowAdaptor::get_bottom_dimension_hidden_ids(); + + set_container_subset(window_adaptor_vector_lengths, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_lengths); + set_container_subset(window_adaptor_vector_strides, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_strides); + + const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = + WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( + window_adaptor_vector_lengths, window_adaptor_vector_strides); + + // [y0, y1, ...] + constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; + + return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), + get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); + } + + CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; } + + template + CK_TILE_DEVICE auto load(number = {}, + bool_constant = {}) const + { + constexpr auto tile_dstr = TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + load(dst_tensor, number{}, bool_constant{}); + return dst_tensor; + } + + template + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + number = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = idx_ys_start[number{}]; + const auto page_offset = page_idx_[idx_gather]; + // read from bottom tensor + const vector_t vec_value = + get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + bool_constant{}); +#if 1 + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + dst_tensor.get_thread_buffer().template at() = + vec_value.template get_as()[j / Traits::PackedSize]; + }); +#else + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % Traits::ScalarPerVector == 0); + + dst_tensor.get_thread_buffer().template get_as()( + number{}) = bit_cast(vec_value); +#endif + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // TODO: currently async load only implemented in inline asm + template + CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + // using LdsTensorView = typename LdsTileWindow::BottomTensorView; + using LdsDataType = typename LdsTileWindow::DataType; + // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc; + + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + const index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + m0_set_with_memory(m0_init_value); // This should be wave independent + + using Traits = load_store_traits; + + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + constexpr auto pre_nop_ = [&]() { + if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0) + return bool_constant{}; + else + return bool_constant{}; + }(); + + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = idx_ys_start[number{}]; + const auto page_offset = page_idx_[idx_gather]; + // read from bottom tensor + get_bottom_tensor_view().template async_get_vectorized_elements_raw( + smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + m0_inc_with_memory(size_per_issue); + } + }); + }); + } + + template + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + number = {}, + bool_constant = {}) const + { + using Traits = load_store_traits; + + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + // printf("off %d\n", page_idx_[I0]); + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = idx_ys_start[number<0>{}]; + const auto page_offset = page_idx_[idx_gather]; + + // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n", + // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0); + + // read from distributed tensor + // vector_type_t vec; + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j); + vec_value.template get_as()(j / Traits::PackedSize) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // const vector_t vec_value = vec.template get_as().template at<0>(); + + // write into bottom tensor + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, + page_offset, + vec_value, + bool_constant{}); + // printf("coord_offset:%d, scatter_offset:%d \n", + // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, + number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // move thread's botom tensor coordiante + // [x0', x1', ... ] ==> [offset] + // also move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) + { + window_origin_ += step; + BottomTensorIndex step_new = step; + step_new(HsGatherDim) = 0; + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_coords_(iCoord)(I1), + step_new); + }); + } + + CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) + { + page_idx_ = new_idx; + + // static_for<0, 2, 1>{}([&](auto k0) { + // printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]); + // }); + } + CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin) + { + window_origin_ = new_window_origin; + +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_dstr_.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_dstr_), array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index(); + + bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0; + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } + + CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); } + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; + + // + WindowLengths window_lengths_; + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; + + // Tile tensor distribution, which contains: + // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] + // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] + TileDstr tile_dstr_; + + PageIdxArray page_idx_; + + // this contains: + // per-thread coordinate for window adaptor + // per-thread coordinate for bottom tensor + array, NumCoord> pre_computed_coords_; +}; + +// TODO: use strategy +template +CK_TILE_DEVICE constexpr auto +make_tile_scatter_gather(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + const StaticPageIndexArray_& page_idx, + number = {}, + number = {}) +{ + return tile_scatter_gather, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + HsGatherDim, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution, page_idx}; +} + +template +CK_TILE_DEVICE constexpr auto make_tile_scatter_gather( + const tile_window_with_static_lengths& tile_window, + const multi_index& origin, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + origin, + tile_distribution, + page_idx, + number{}); +} + +template +CK_TILE_DEVICE constexpr auto make_tile_scatter_gather( + const tile_window_with_static_lengths& tile_window, + const StaticTileDistribution& tile_distribution, + const StaticPageIndexArray& page_idx, + number = {}) +{ + return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(), + tile_window.get_window_lengths(), + tile_window.get_window_origin(), + tile_distribution, + page_idx, + number{}); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 3bb728df23..716b1f4ecb 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -1164,4 +1164,82 @@ CK_TILE_DEVICE void move_tile_window( window.move(step); } +/** + * @brief Type trait to determine if a type is a tile window with static distribution. + * + * Defaults to `false_type`. Specializations define when the trait evaluates to `true`. + * + * @tparam T The type to check. + */ +template +struct is_tile_window_with_static_distribution : std::false_type +{ +}; + +/** + * @brief Specialization for `tile_window_with_static_distribution` to evaluate to `true_type`. + * + * @tparam BottomTensorView_ Bottom tensor view type of the tile window. + * @tparam WindowLengths_ Static window lengths. + * @tparam StaticTileDistribution_ Tile distribution policy. + * @tparam NumCoord Number of coordinate dimensions. + */ +template +struct is_tile_window_with_static_distribution< + tile_window_with_static_distribution> : std::true_type +{ +}; + +/** + * @brief Helper variable template to check if a type is a tile window with static distribution. + * + * Equivalent to `is_tile_window_with_static_distribution::value`. + * + * @tparam T The type to check. + */ +template +inline constexpr bool is_tile_window_with_static_distribution_v = + is_tile_window_with_static_distribution::value; + +/** + * @brief Type trait to determine if a type is a tile window with static lengths. + * + * Defaults to `false_type`. Specializations define when the trait evaluates to `true`. + * + * @tparam T The type to check. + */ +template +struct is_tile_window_with_static_lengths : std::false_type +{ +}; + +/** + * @brief Specialization for `tile_window_with_static_lengths` to evaluate to `true_type`. + * + * @tparam BottomTensorView_ Bottom tensor view type of the tile window. + * @tparam WindowLengths_ Static window lengths. + */ +template +struct is_tile_window_with_static_lengths< + tile_window_with_static_lengths> : std::true_type +{ +}; + +/** + * @brief Helper variable template to check if a type is a tile window with static lengths. + * + * Equivalent to `is_tile_window_with_static_lengths::value`. + * + * @tparam T The type to check. + */ +template +inline constexpr bool is_tile_window_with_static_lengths_v = + is_tile_window_with_static_lengths::value; + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index 1e24e660f6..5ecaf5ca17 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -44,6 +44,7 @@ template struct tile_window_linear { + using BottomTensorView = remove_reference_t; using WindowLengths = remove_cvref_t; using TileDstr = remove_cvref_t; @@ -1215,4 +1216,49 @@ CK_TILE_DEVICE void move_tile_window( window.move(step); } +/** + * @brief Type trait to determine if a type is a linear tile window. + * + * Defaults to `false_type`. Specialized to `true_type` for types that match + * `tile_window_linear<...>`. + * + * @tparam T The type to check. + */ +template +struct is_tile_window_linear : std::false_type +{ +}; + +/** + * @brief Specialization of `is_tile_window_linear` for `tile_window_linear`. + * + * Evaluates to `true_type` if the type is a `tile_window_linear` with the given template + * parameters. + * + * @tparam BottomTensorView_ Bottom tensor view type of the tile window. + * @tparam WindowLengths_ Static window lengths. + * @tparam StaticTileDistribution_ Tile distribution policy. + * @tparam LinearBottomDims_ Dimensions of the bottom tensor view that participate in linearization. + */ +template +struct is_tile_window_linear> : std::true_type +{ +}; + +/** + * @brief Helper variable template to check if a type is a linear tile window. + * + * Equivalent to `is_tile_window_linear::value`. + * + * @tparam T The type to check. + */ +template +inline constexpr bool is_tile_window_linear_v = is_tile_window_linear::value; + } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window_utils.hpp b/include/ck_tile/core/tensor/tile_window_utils.hpp index 71a72329f8..f8b232a7af 100644 --- a/include/ck_tile/core/tensor/tile_window_utils.hpp +++ b/include/ck_tile/core/tensor/tile_window_utils.hpp @@ -18,6 +18,13 @@ #pragma once namespace ck_tile { +template +CK_TILE_DEVICE void move_tile_window(TileWindow_& window, + const typename TileWindow_::BottomTensorIndex& step) +{ + window.move(step); +} + // input a lds store tile, extract some information from it // used to set m0 value for gfx9 serious template diff --git a/include/ck_tile/core/tensor/transpose_tile.hpp b/include/ck_tile/core/tensor/transpose_tile.hpp index f34efe5c2f..d917cd5bac 100644 --- a/include/ck_tile/core/tensor/transpose_tile.hpp +++ b/include/ck_tile/core/tensor/transpose_tile.hpp @@ -83,12 +83,14 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, constexpr index_t num_vec_in = vec_length_out; constexpr index_t num_vec_out = vec_length_in; - using InVec = array; - using OutVec = array; - // SFC constexpr auto scalars_per_access_arr = generate_array( - [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; }, + [&](auto i) { + if constexpr(vec_length_in == 1) + return 1; + else + return (i == y_dim_vec_in || i == y_dim_vec_out) ? y_lengths[i] : 1; + }, number{}); constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY); @@ -101,51 +103,90 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor, static_assert(num_access > 0, "wrong! num_access should be larger than 0"); - // in/out vectors to be transposed - thread_buffer in_vectors; - thread_buffer out_vectors; - - // loop over SFC and do transpose - static_for<0, num_access, 1>{}([&](auto iAccess) { - // data index [y0, y1, ...] in the order of input tensor - constexpr auto idx_y_start = SFC_Y::get_index(iAccess); - - // get input vectors - static_for<0, num_vec_in, 1>{}([&](auto i) { - constexpr auto idx_y_in = generate_tuple( - [&](auto ii) { - return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii]; - }, - number{}); - + if constexpr(num_vec_in == 1 || num_vec_out == 1) + { + // loop over SFC + static_for<0, num_access, 1>{}([&](auto iAccess) { + // data index [y0, y1, ...] in the order of input tensor + constexpr auto idx_y_start = SFC_Y::get_index(iAccess); + constexpr auto idx_y_in = + generate_tuple([&](auto ii) { return idx_y_start[ii].value; }, number{}); constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); static_assert(in_offset % vec_length_in == 0); - - in_vectors(i).template get_as()(I0) = - in_tensor.get_thread_buffer() - .template get_as()[number{}]; - }); - - // transpose - transpose_vectors{}(in_vectors, out_vectors); - - // set output vectors - static_for<0, num_vec_out, 1>{}([&](auto i) { - constexpr auto idx_y_out_tmp = generate_array( - [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; }, - number{}); - + constexpr auto idx_y_out_tmp = + generate_array([&](auto ii) { return idx_y_start[ii].value; }, number{}); constexpr auto idx_y_out = container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); - constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); - static_assert(out_offset % vec_length_out == 0); + if constexpr(vec_length_in == 1) + { - out_tensor.get_thread_buffer().template set_as( - number{}, - out_vectors[i].template get_as()[I0]); + out_tensor.get_thread_buffer()[number{}] = + in_tensor.get_thread_buffer()[number{}]; + } + else + { + using Vec = array; + out_tensor.get_thread_buffer().template get_as( + number{}) = + in_tensor.get_thread_buffer().template get_as( + number{}); + } }); - }); + } + else + { + using InVec = array; + using OutVec = array; + + // in/out vectors to be transposed + thread_buffer in_vectors; + thread_buffer out_vectors; + + // loop over SFC and do transpose + static_for<0, num_access, 1>{}([&](auto iAccess) { + // data index [y0, y1, ...] in the order of input tensor + constexpr auto idx_y_start = SFC_Y::get_index(iAccess); + + // get input vectors + static_for<0, num_vec_in, 1>{}([&](auto i) { + constexpr auto idx_y_in = generate_tuple( + [&](auto ii) { + return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii]; + }, + number{}); + + constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); + static_assert(in_offset % vec_length_in == 0); + + in_vectors(i).template get_as()(I0) = + in_tensor.get_thread_buffer() + .template get_as()[number{}]; + }); + + // transpose + transpose_vectors{}(in_vectors, out_vectors); + + // set output vectors + static_for<0, num_vec_out, 1>{}([&](auto i) { + constexpr auto idx_y_out_tmp = generate_array( + [&](auto ii) { + return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; + }, + number{}); + + constexpr auto idx_y_out = + container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); + + constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); + static_assert(out_offset % vec_length_out == 0); + + out_tensor.get_thread_buffer().template set_as( + number{}, + out_vectors[i].template get_as()[I0]); + }); + }); + } } } // namespace detail diff --git a/include/ck_tile/core/utility/env.hpp b/include/ck_tile/core/utility/env.hpp index 5b0b7a9071..9b148b3e0b 100644 --- a/include/ck_tile/core/utility/env.hpp +++ b/include/ck_tile/core/utility/env.hpp @@ -202,3 +202,7 @@ void UpdateEnvVar(EnvVar, const std::string_view& val) } } // namespace ck_tile + +// environment variable to enable logging: +// export CK_TILE_LOGGING=ON or CK_TILE_LOGGING=1 or CK_TILE_LOGGING=ENABLED +CK_TILE_DECLARE_ENV_VAR_BOOL(CK_TILE_LOGGING) diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 006026470b..3f64eb28cd 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -280,7 +280,7 @@ struct FillMonotonicSeq template void operator()(ForwardIter first, ForwardIter last) const { - std::generate(first, last, [=, n = init_value_]() mutable { + std::generate(first, last, [=, *this, n = init_value_]() mutable { auto tmp = n; if constexpr(std::is_same_v) { @@ -315,7 +315,7 @@ struct FillStepRange template void operator()(ForwardIter first, ForwardIter last) const { - std::generate(first, last, [=, n = start_value_]() mutable { + std::generate(first, last, [=, *this, n = start_value_]() mutable { auto tmp = n; n += step_; if constexpr(IsAscending) @@ -364,6 +364,49 @@ struct FillConstant } }; +//---------------------------------------------------------------------------------------------- +/// @brief Transforms given input to fit 2:4 structured sparsity pattern so +/// every subgroup of 4 elements contain at most 2 non-zero elements +template +struct AdjustToStructuredSparsity +{ + size_t start{0}; + // masks represent all valid 2:4 structured sparsity permutations + // clang-format off + static constexpr int32_t masks[] = {0, 0, 1, 1, + 0, 1, 0, 1, + 0, 1, 1, 0, + 1, 0, 0, 1, + 1, 0, 1, 0, + 1, 1, 0, 0, + 0, 0, 0, 1, + 0, 0, 1, 0, + 0, 1, 0, 0, + 1, 0, 0, 0}; + // clang-format on + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::transform(first, last, first, [=, *this, index = start](T val) mutable { + auto tmp = val * masks[index % (sizeof(masks) / sizeof(int32_t))]; + index += 1; + + return type_convert(tmp); + }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + template struct FillTrigValue { diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index 7e7dd03c6a..4c3aa2ba29 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -19,7 +19,6 @@ struct BatchedTransposeHostArgs index_t batch; index_t height; index_t width; - // index_t dim_blocks; index_t dim_stride; index_t dim_block_h; index_t dim_block_w; @@ -28,8 +27,10 @@ struct BatchedTransposeHostArgs template struct BatchedTransposeKernel { - using Pipeline = remove_cvref_t; - using Problem = remove_cvref_t; + + CK_TILE_DEVICE static index_t counter = 0; + using Pipeline = remove_cvref_t; + using Problem = remove_cvref_t; using Type = typename Problem::InputType; @@ -46,11 +47,11 @@ struct BatchedTransposeKernel using Kargs = BatchedTransposeKargs; using Hargs = BatchedTransposeHostArgs; - CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + CK_TILE_HOST static constexpr auto GridSize(const Hargs& host_args) { - size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w; - size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h; - size_t grid_size_z = h.batch; + size_t grid_size_x = (host_args.height + host_args.dim_block_h - 1) / host_args.dim_block_h; + size_t grid_size_y = (host_args.width + host_args.dim_block_w - 1) / host_args.dim_block_w; + size_t grid_size_z = host_args.batch; return dim3(grid_size_x, grid_size_y, grid_size_z); } @@ -70,58 +71,52 @@ struct BatchedTransposeKernel CK_TILE_DEVICE void operator()(Kargs kargs) const { + static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; + static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr ck_tile::index_t VectorSizeInput = Problem::VectorSizeInput; + static constexpr ck_tile::index_t VectorSizeOutput = Problem::VectorSizeOutput; - static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; + const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock); + const auto iDim = blockIdx.z; - static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread; - static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread; - - static_assert(kMPerThread == 1 && kNPerThread == 1); - - const auto iDim = blockIdx.z; const auto x_m_n = [&]() { const auto x_dram_naive = make_naive_tensor_view( static_cast(kargs.p_input) + iDim * kargs.dim_stride, make_tuple(kargs.height, kargs.width), make_tuple(kargs.width, 1), - number{}, // TODO thread load value + number{}, number<1>{}); return pad_tensor_view(x_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); - const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock); - const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kNPerBlock); - const auto y_n_m = [&]() { const auto y_dram_naive = make_naive_tensor_view( static_cast(kargs.p_output) + iDim * kargs.dim_stride, make_tuple(kargs.width, kargs.height), make_tuple(kargs.height, 1), - number{}, + number{}, number<1>{}); return pad_tensor_view(y_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); - auto x_block_window = - make_tile_window(x_m_n, - make_tuple(number{}, number{}), - {static_cast(iM * kMPerBlock), - static_cast(iN * kNPerBlock)}); + auto x_block_window = make_tile_window( + x_m_n, + make_tuple(number{}, number{}), + {static_cast(iM), static_cast(iN)}); - auto y_block_window = - make_tile_window(y_n_m, - make_tuple(number{}, number{}), - {static_cast(iN * kNPerBlock), - static_cast(iM * kMPerBlock)}); + auto y_block_window = make_tile_window( + y_n_m, + make_tuple(number{}, number{}), + {static_cast(iN), static_cast(iM)}); Pipeline{}(x_block_window, y_block_window); } diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp index aa62333918..e815313c06 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp @@ -29,24 +29,18 @@ struct BatchedTransposePipeline { auto inp_win = make_tile_window(input_window, Policy::template MakeInputDistribution()); + + auto input_tile = load_tile(inp_win); + + auto output_tile = make_static_distributed_tensor( + Policy::template MakeOutputDistribution()); + + transpose_tile2d(output_tile, input_tile); + auto out_win = make_tile_window(out_window, Policy::template MakeOutputDistribution()); - auto x = load_tile(inp_win); // x->thread input_win->block - - auto y = make_static_distributed_tensor( - Policy::template MakeOutputDistribution()); - - constexpr auto span_2d_x = decltype(x)::get_distributed_spans(); - - sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) { - sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx1, idx0); - y(i_j_idx) = x(i_j_idx); - }); - }); - - store_tile(out_win, y); + store_tile(out_win, output_tile); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp index 9953e8b8bf..dd9a6d79a8 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -14,31 +14,34 @@ struct BatchedTransposePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() { - using S = Problem; - return make_static_tile_distribution( - tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 1>>, - sequence<1, 2>, - sequence<2, 2>>{}); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::kMPerBlock; + constexpr index_t NPerBlock = Problem::kNPerBlock; + constexpr index_t VecLoadSize = Problem::VectorSizeInput; + using TileEncodingPattern = + TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); } template CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() { - using S = Problem; - return make_static_tile_distribution( - tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 1>>, - sequence<2, 1>, - sequence<2, 2>>{}); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::kMPerBlock; + constexpr index_t NPerBlock = Problem::kNPerBlock; + constexpr index_t VecLoadSize = Problem::VectorSizeOutput; + + using TileEncodingPattern = + TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp index af6b2d51aa..fd5ea004b6 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/core.hpp" -#include #include #define VectorLoadSize 16 @@ -12,11 +11,11 @@ namespace ck_tile { template + typename BlockTile, // Sequence<... + typename WarpTile, // Sequence<... + typename ThreadTile, + bool kPadM_ = false, + bool kPadN_ = false> // Sequence<... struct BatchedTransposeProblem { using InputType = remove_cvref_t; @@ -42,7 +41,7 @@ struct BatchedTransposeProblem static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; - static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1; // TODO - static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1; + static constexpr index_t VectorSizeInput = kPadM ? 1 : VectorLoadSize / sizeof(InputType); + static constexpr index_t VectorSizeOutput = kPadN ? 1 : VectorLoadSize / sizeof(InputType); }; } // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 12e53e13e6..6cc0fa8540 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -4,9 +4,9 @@ #pragma once #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" +#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" -#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 0081edcb2e..9b8dde1905 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -22,23 +22,25 @@ template + bool isCTransposed_, + memory_operation_enum MemoryOperation_> struct CShuffleEpilogueProblem { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using CLayout = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; - static constexpr index_t kMWave = kMWave_; - static constexpr index_t kNWave = kNWave_; - static constexpr index_t kMPerXdl = kMPerXdl_; - static constexpr index_t kNPerXdl = kNPerXdl_; - static constexpr index_t kKPerXdl = kKPerXdl_; - static constexpr index_t isCTransposed = isCTransposed_; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t kMWave = kMWave_; + static constexpr index_t kNWave = kNWave_; + static constexpr index_t kMPerXdl = kMPerXdl_; + static constexpr index_t kNPerXdl = kNPerXdl_; + static constexpr index_t kKPerXdl = kKPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; + static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; }; template @@ -49,20 +51,22 @@ struct CShuffleEpilogue using BDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = - std::conditional_t, ODataType, BDataType>; - using CLayout = remove_cvref_t; - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kMPerBlock = Problem::kMPerBlock; - static constexpr index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t kMWave = Problem::kMWave; - static constexpr index_t kNWave = Problem::kNWave; - static constexpr index_t kMPerXdl = Problem::kMPerXdl; - static constexpr index_t kNPerXdl = Problem::kNPerXdl; - static constexpr index_t kKPerXdl = Problem::kKPerXdl; - static constexpr index_t isCTransposed = Problem::isCTransposed; - static constexpr index_t kMPerIteration = kMPerXdl * kMWave; - static constexpr index_t kNPerIteration = kNPerXdl * kNWave; + std::conditional_t, ADataType, BDataType>; + using CLayout = remove_cvref_t; + static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMPerBlock = Problem::kMPerBlock; + static constexpr index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t kMWave = Problem::kMWave; + static constexpr index_t kNWave = Problem::kNWave; + static constexpr index_t kMPerXdl = Problem::kMPerXdl; + static constexpr index_t kNPerXdl = Problem::kNPerXdl; + static constexpr index_t kKPerXdl = Problem::kKPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr index_t kMPerIteration = kMPerXdl * kMWave; + static constexpr index_t kNPerIteration = kNPerXdl * kNWave; using WG = WarpGemmMfmaDispatcher + template CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem) { @@ -178,7 +180,7 @@ struct CShuffleEpilogue const auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution)); - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); } diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 6e290fe6d7..a2915f5c8f 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -15,17 +15,21 @@ template + bool UseRawStore_ = true, + memory_operation_enum MemoryOperation_ = memory_operation_enum::set> struct Default2DEpilogueProblem { - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool UseRawStore = UseRawStore_; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool UseRawStore = UseRawStore_; + static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; }; -template -struct DefaultGemm2DEpilogueProblem - : public Default2DEpilogueProblem + bool UseRawStore_ = true, + memory_operation_enum MemoryOperation_ = memory_operation_enum::set> +struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem { + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = kMPerXdl_; static constexpr index_t kNPerXdl = kNPerXdl_; @@ -54,14 +65,13 @@ struct Default2DEpilogue static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool UseRawStore = Problem::UseRawStore; + static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? - template + template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) { @@ -69,7 +79,7 @@ struct Default2DEpilogue // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); } @@ -81,7 +91,7 @@ struct Default2DEpilogue } else { - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); } @@ -96,17 +106,22 @@ struct Default2DEpilogue template struct DefaultGemm2DEpilogue : public Default2DEpilogue { - using Problem = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; + using Problem = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + // Used for weight-only quantization kernel, B would be dequantized to the same data type as A + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = Problem::kMPerXdl; static constexpr index_t kNPerXdl = Problem::kNPerXdl; static constexpr index_t kKPerXdl = Problem::kKPerXdl; static constexpr index_t isCTransposed = Problem::isCTransposed; - using WG = WarpGemmMfmaDispatcher +struct BlockFlatmmASmemBSmemCRegV1 +{ + using Problem = remove_cvref_t; + using BlockPolicy = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + CK_TILE_DEVICE static constexpr auto MakeCBlockTile() + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t NPerBlock = BlockGemmShape::kN; + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); + + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + ABlockWindow& a_warp_windows, + BFlatBlockTensor& b_warp_tensor) const + { + constexpr index_t MPerBlock = BlockGemmShape::kM; + constexpr index_t KPerBlock = BlockGemmShape::kK; + + constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = + BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN)); + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // read A warp tensor from A block window + const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp new file mode 100644 index 0000000000..d5b062a1b3 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// Default policy for BlockGemmASmemBSmemCRegV1 +// Default policy class should not be templated, put template on member functions instead +template +struct BlockFlatmmASmemBSmemCRegV1CustomPolicy +{ + using AType = remove_cvref_t; + using BType = remove_cvref_t; + using CType = remove_cvref_t; + + using BlockWarps = remove_cvref_t; + + static constexpr index_t kMWarps = BlockWarps::at(number<0>{}); + static constexpr index_t kNWarps = BlockWarps::at(number<1>{}); + static constexpr index_t kKWarps = BlockWarps::at(number<2>{}); + + using WarpGemm = remove_cvref_t; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + return make_tuple(WarpGemm{}, kMWarps, kNWarps); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp new file mode 100644 index 0000000000..a9ed1519e6 --- /dev/null +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -0,0 +1,486 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" + +namespace ck_tile { + +struct FlatmmProblem +{ + CK_TILE_HOST FlatmmProblem() = default; + CK_TILE_HOST FlatmmProblem( + index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_) + : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_) + { + } + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; +}; + +struct FlatmmHostArgs : public FlatmmProblem +{ + CK_TILE_HOST FlatmmHostArgs() = default; + CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_, + const void* b_shuffle_ptr_, + void* c_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_) + : FlatmmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_), + a_ptr(a_ptr_), + b_shuffle_ptr(b_shuffle_ptr_), + c_ptr(c_ptr_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_shuffle_ptr; + void* c_ptr; + index_t k_batch; +}; + +template +struct FlatmmKernel +{ + using TilePartitioner = remove_cvref_t; + using FlatmmPipeline = remove_cvref_t; + using BlockGemmShape = + remove_cvref_t; // TileFlatmmShape + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + // Below type is actually accumulation data type - the output of block GEMM. + using CDataType = remove_cvref_t; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str, FlatmmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + struct FlatmmKernelArgs + { + const void* a_ptr; + const void* b_shuffle_ptr; + void* c_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + index_t stride_C; + index_t k_batch; + }; + + CK_TILE_HOST static constexpr FlatmmKernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs) + { + return FlatmmKernelArgs{hostArgs.a_ptr, + hostArgs.b_shuffle_ptr, + hostArgs.c_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_A, + hostArgs.stride_B, + hostArgs.stride_C, + hostArgs.k_batch}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = kargs.k_batch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + + if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead; + } + else if constexpr(std::is_same_v) + { + a_k_split_offset = k_id * KRead * kargs.stride_A; + } + + if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead * kargs.stride_B; + } + else if constexpr(std::is_same_v) + { + b_k_split_offset = k_id * KRead; + } + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = KRead; + } + else + { + splitted_k = kargs.K - KRead * (kargs.k_batch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs) + { + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) + { + if(kargs.k_batch != 1) + { + std::cerr << "Conditions not met for Kbatch >1 !" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) + { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0) + { + std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) + { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0) + { + std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) + { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0) + { + std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false) + { + std::cerr << "Can't support K that is not a multiple of KPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0) + { + std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl; + return false; + } + } + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false) + { + std::cerr << "Can't support N that is not a multiple of NPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + { + std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl; + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false) + { + std::cerr << "Can't support M that is not a multiple of MPerBlock" + " without padding!" + << std::endl; + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl; + return false; + } + } + return true; + } + + template + CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr, + const BDataType* b_flat_ptr, + CDataType* c_ptr, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k / + BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + const auto& b_flat_tensor_view = [&]() { + return make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + }(); + + // TODO: enable vector write for C in ColMajor + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + const auto& b_flat_tensor_view = views.at(I1); + + // TODO vector write in for C in ColMajor + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I2); + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& a_pad_view = views.at(I0); + const auto& b_flat_pad_view = views.at(I1); + const auto& c_pad_view = views.at(I2); + + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + const auto& b_flat_block_window = + make_tile_window(b_flat_pad_view, + make_tuple(number{}, + number{}), + {static_cast(i_n / BlockGemmShape::WarpTile::at(idxN)), 0}); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_flat_block_window, c_block_window); + } + + CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr, + const BDataType* b_flat_ptr, + CDataType* c_ptr, + void* smem_ptr, + const FlatmmKernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_flat_block_window = gemm_tile_windows.at(I1); + const auto& c_block_tile = FlatmmPipeline{}.template operator()( + a_block_window, b_flat_block_window, num_loop, smem_ptr); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I2); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, smem_ptr); + } + + CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const + { + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + // options + const ADataType* a_ptr = + static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; + const BDataType* b_flat_ptr = static_cast(kargs.b_shuffle_ptr) + + splitk_batch_offset.b_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); + + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp new file mode 100644 index 0000000000..cbd20a6ea3 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -0,0 +1,430 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" + +namespace ck_tile { + +template +struct FlatmmPipelineAGmemBGmemCRegV1 +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockFlatmm = + remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; + static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + + static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; } + static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr index_t kLdsAlignmentInBytes = 16; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV1", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + + // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. + static constexpr bool DoubleSmemBuffer = false; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return PipelinePolicy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { +#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) || defined(USING_MFMA_32x32x16) + constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + + constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; + constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; + constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; +#endif +#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA + }); + +#elif defined(USING_MFMA_32x32x16) + static_for<0, + A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - B_Buffer_Load_Inst_Num, + 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA +#endif + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + constexpr index_t MWarp = config.template at<1>(); + constexpr index_t NWarp = config.template at<2>(); + + constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + auto a_warp_window_tmp = make_tile_window( + a_lds_gemm_window.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_lds_gemm_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Block GEMM + auto block_flatmm = BlockFlatmm(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + // Acc register tile + auto c_block_tile = block_flatmm.MakeCBlockTile(); + + // prefetch + // global read 0 + auto a_block_tile = load_tile(a_copy_dram_window); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_2; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + { + // move to 1 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // move to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // LDS write 0 + if constexpr(std::is_same_v) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + PipelinePolicy::template MakeShuffledARegBlockDistribution()); + shuffle_tile(a_shuffle_tmp, a_block_tile); + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } + else + { + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + } + block_sync_lds(); + } + + index_t iCounter = num_loop / 2 - 1; + while(iCounter > 0) + { + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + + // GEMM i + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor); + + block_sync_lds(); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // move to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // LDS write i + 1 + auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + HotLoopScheduler(); + block_sync_lds(); + + // iCounter--; + + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + + // GEMM i + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); + + block_sync_lds(); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // move to i + 2 + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // move to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // LDS write i + 1 + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + HotLoopScheduler(); + block_sync_lds(); + + iCounter--; + } + + // tail + { + // global read i + 1 + a_block_tile = load_tile(a_copy_dram_window); + + // GEMM i + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor); + + block_sync_lds(); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // move to i + 2 + // move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // move to next flat K + // move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + HotLoopScheduler(); + block_sync_lds(); + + // GEMM num_loop - 1 + block_flatmm(c_block_tile, a_warp_windows, b_warp_tensor_2); + } + + return c_block_tile; + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp new file mode 100644 index 0000000000..1a1b729394 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +struct UniversalFlatmmPipelineAgBgCrPolicy +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; +#if defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) + /*reduce transform layers,compare with old ck*/ + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple( + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; +#elif defined(USING_MFMA_32x32x16) + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = GetSmemPackA(); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; +#endif +/*xor*/ +#if 0 + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = GetSmemPackA(); + using ADataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return a_lds_block_desc; +#endif + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + return Problem::VectorLoadSize / sizeof(typename Problem::ADataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() >= (K2 * M0)) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M2, M1 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M1, M2 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using BDataType = remove_cvref_t; + + using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t KBPerLoad = + Problem::VectorLoadSize / sizeof(BDataType); // dwordx4 load B elem cnt + constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(TileShape::idxN); // N_Warp + constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? + tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<1, 2>>, // which direction + tuple, sequence<2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = kMPerBlock / M1; + constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetSmemPackA(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size >= (K2 * M0)) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = kBlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() + { + using AccDataType = float; + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + + using BlockFlatmmPolicy = + BlockFlatmmASmemBSmemCRegV1CustomPolicy; + return BlockFlatmmASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp new file mode 100644 index 0000000000..551d390ec6 --- /dev/null +++ b/include/ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +template +struct TileFlatmmShape +{ + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr auto idxM = number<0>{}; + static constexpr auto idxN = number<1>{}; + static constexpr auto idxK = number<2>{}; + + static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + + static constexpr index_t kM = BlockTile::at(idxM); + static constexpr index_t kN = BlockTile::at(idxN); + static constexpr index_t kK = BlockTile::at(idxK); + + static constexpr index_t flatNPerWarp = BlockWarps::at(idxN); + static constexpr index_t flatKPerWarp = WarpTile::at(idxK) * WarpTile::at(idxN); + static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(idxK); + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + return concat('_', "tile_flatmm_shape", + concat('x', kM, kN, kK, NumWarps), + concat('x', BlockWarps::at(idxM), BlockWarps::at(idxN), BlockWarps::at(idxK)), + concat('x', (WarpTile::at(idxM)), WarpTile::at(idxN), WarpTile::at(idxK))); + // clang-format on + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index a28b63f813..ac6ef9cae3 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -9,12 +9,16 @@ #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" #include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp" #include "ck_tile/ops/fmha/block/page_block_navigator.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp" diff --git a/include/ck_tile/ops/fmha/block/variants.hpp b/include/ck_tile/ops/fmha/block/variants.hpp new file mode 100644 index 0000000000..90fc5656fc --- /dev/null +++ b/include/ck_tile/ops/fmha/block/variants.hpp @@ -0,0 +1,274 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include +#include + +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0 +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1 + +#ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT +#define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH +#endif + +namespace ck_tile { + +template +struct StandardAttentionParams +{ + __device__ __host__ StandardAttentionParams(const ImplMask& impl_mask_, float sm_scale_) + : impl_mask(impl_mask_), sm_scale(sm_scale_) + { + } + + const ImplMask& impl_mask; + float sm_scale; +}; + +template +struct LogitsSoftCapParams +{ + __device__ + LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) + : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) + { + if(0.f < logits_soft_cap) + { + logits_soft_cap_rcp = __builtin_amdgcn_rcpf(logits_soft_cap); + } + else + { + logits_soft_cap_rcp = 0.f; + } + + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + __host__ + LogitsSoftCapParams(const ImplMask& impl_mask_, float sm_scale_, float logits_soft_cap_) + : impl_mask(impl_mask_), sm_scale(sm_scale_), logits_soft_cap(logits_soft_cap_) + { + if(0.f < logits_soft_cap) + { + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap_rcp = 0.f; + } + + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + __device__ __host__ LogitsSoftCapParams(const ImplMask& impl_mask_, + float sm_scale_, + float logits_soft_cap_, + float logits_soft_cap_rcp_) + : impl_mask(impl_mask_), + sm_scale(sm_scale_), + logits_soft_cap(logits_soft_cap_), + logits_soft_cap_rcp(logits_soft_cap_rcp_) + { + // move computation here to prevent compiler from generating inefficient instruction + // sequence + if constexpr(UseExp2) + { + logits_soft_cap = log2e_v * logits_soft_cap; + logits_soft_cap_rcp = sm_scale * log2e_rcp_v * logits_soft_cap_rcp; + } + } + + const ImplMask& impl_mask; + float sm_scale; + float logits_soft_cap; + float logits_soft_cap_rcp; +}; + +struct StandardAttention +{ + __device__ __host__ StandardAttention() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + return type_convert(q) * params.sm_scale; + } + + /// NOTICE: For better performance, we simpliy transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return logits; + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +template +struct LogitsSoftCap +{ + __device__ __host__ LogitsSoftCap() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + if constexpr(UseExp2) + { + return q; + } + else + { + return type_convert(q) * params.sm_scale; + } + } + + /// NOTICE: For better performance, we simpliy transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform(const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + if constexpr(UseExp2) + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return params.sm_scale * type_convert(logits) * + rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + else + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanhf(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return type_convert(logits) * + rcp(1.f + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +constexpr uint32_t CUSTOM_MASK = 1U; +constexpr uint32_t SLIDING_WINDOW = 2U; +constexpr uint32_t LOGITS_SOFT_CAP = 4U; +constexpr uint32_t ALIBI = 8U; + +template +struct ComposedAttention +{ + static constexpr bool use_exp2 = UseExp2; + + static constexpr bool use_logits_soft_cap = (VARIANT_CODE & LOGITS_SOFT_CAP) != 0; + + __device__ __host__ ComposedAttention() = default; + + template + __device__ __forceinline__ T QueryTransform(const Params& params, T q) const + { + if constexpr(use_logits_soft_cap && UseExp2) + { + return q; + } + return type_convert(q) * params.sm_scale; + } + + /// NOTICE: For better performance, we simpliy transform thread buffer without calculating + /// qo_idx/kv_idx. + template + __device__ __forceinline__ T LogitsTransform(const Params& params, + T logits, + [[maybe_unused]] uint32_t batch_idx, + /*uint32_t qo_idx, uint32_t kv_idx,*/ + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + if constexpr(use_logits_soft_cap) + { + if constexpr(UseExp2) + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanh_fast(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return params.sm_scale * type_convert(logits) * + rcp(1.f + + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + else + { +#if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH + return params.logits_soft_cap * + tanhf(type_convert(logits) * params.logits_soft_cap_rcp); +#elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN + return type_convert(logits) * + rcp(1.f + + abs(type_convert(logits) * params.logits_soft_cap_rcp)); +#endif + } + } + return logits; + } + + template + __device__ __forceinline__ bool LogitsMask(const Params& params, + [[maybe_unused]] uint32_t batch_idx, + uint32_t qo_idx, + uint32_t kv_idx, + [[maybe_unused]] uint32_t qo_head_idx, + [[maybe_unused]] uint32_t kv_head_idx) const + { + return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp new file mode 100644 index 0000000000..ba327ee511 --- /dev/null +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp @@ -0,0 +1,1134 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct FmhaBatchPrefillWithPagedKVCacheKernel +{ + using FmhaPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu; + + using QDataType = ck_tile::remove_cvref_t; + using KDataType = ck_tile::remove_cvref_t; + using VDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using RandValOutputDataType = + ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; + static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; + static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = FmhaMask::IsMasking; + + static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { + // sync with generate.py + // clang-format off + using bfs = typename FmhaPipeline::BlockFmhaShape; + using g0br = typename bfs::Gemm0BlockWarps; + using g1br = typename bfs::Gemm1BlockWarps; + using g0wt = typename bfs::Gemm0WarpTile; + using g1wt = typename bfs::Gemm1WarpTile; + #define _SS_ std::string + #define _TS_ std::to_string + auto pn = [&] () { + std::string n; + if (kPadSeqLenQ) n += "s"; + if (kPadSeqLenK) n += "sk"; + if (kPadHeadDimQ) n += "d"; + if (kPadHeadDimV) n += "dv"; + return n.empty() ? n : std::string("p") + n; }(); + return + _SS_("fmha_batch_prefill_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s::name) + + "_" + (kIsGroupMode ? "group" : "batch") + "_" + "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + + _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" + + "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" + + "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" + + "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); + #undef _SS_ + #undef _TS_ + // clang-format on + } + + template // to avoid duplicated base class prblem, introduce an template + // arg + struct FmhaFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. + struct FmhaFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head_q; + // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k + // if this param is larger than 1, indicate MQA/GQA case + ck_tile::index_t nhead_ratio_qk; + + int32_t num_total_pages; + const int32_t* kv_indptr; + const int32_t* kv_page_indices; +#if 0 // we assume page_block_size=1 for now + const int32_t* kv_last_page_lens; + ck_tile::index_t page_block_size; +#else + static constexpr ck_tile::index_t page_block_size = 1; +#endif + + float scale_s; + + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + }; + + struct FmhaFwdLogitsSoftCapKargs + { + FmhaFwdLogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + + struct FmhaFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct FmhaFwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + + struct FmhaFwdMaskKargs + { + // ck_tile::index_t window_size_left, window_size_right; + ck_tile::index_t window_size_left, window_size_right; + ck_tile::GenericAttentionMaskEnum mask_type; + }; + + struct FmhaFwdFp8StaticQuantKargs + { + float scale_p; + float scale_o; + }; + + struct FmhaFwdCommonLSEKargs + { + void* lse_ptr = nullptr; + ck_tile::index_t nhead_stride_lse = 0; + ck_tile::index_t batch_stride_lse = 0; + }; + + struct FmhaFwdDropoutSeedOffset + { + template + union ValueOrPointer + { + T val; + const T* ptr; + }; + + ValueOrPointer drop_seed; + ValueOrPointer drop_offset; + bool is_drop_seed_offset_from_host; + }; + + struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.val = seed; + this->drop_offset.val = offset; + this->is_drop_seed_offset_from_host = true; + } + + void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed.ptr = seed_ptr; + this->drop_offset.ptr = offset_ptr; + this->is_drop_seed_offset_from_host = false; + } + + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + bool is_store_randval = false; + void* rand_val_ptr = nullptr; + + ck_tile::index_t stride_randval = 0; + ck_tile::index_t nhead_stride_randval = 0; + }; + + struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs + { + ck_tile::index_t batch_stride_randval = 0; + }; + + struct FmhaFwdBatchModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + struct FmhaFwdGroupModeKargs + : FmhaFwdCommonKargs, + std::conditional_t>>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seqstart_q_ptr; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + }; + + using Kargs = std::conditional_t; + + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + ck_tile::index_t seqlen_q, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + int32_t num_total_pages, + const void* kv_indptr, + const void* kv_page_indices, +#if 0 // we assume page_block_size=1 for now + const void* kv_last_page_lens, + ck_tile::index_t page_block_size, +#endif + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_randval, + ck_tile::index_t batch_stride_lse, + ck_tile::index_t batch_stride_o, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen_q, + -1, + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_total_pages, + reinterpret_cast(kv_indptr), + reinterpret_cast(kv_page_indices), +#if 0 // we assume page_block_size=1 for now + reinterpret_cast(kv_last_page_lens), + page_block_size, +#endif +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + kargs.batch_stride_lse = batch_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.batch_stride_randval = batch_stride_randval; + kargs.is_store_randval = s_randval; + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* rand_val_ptr, + void* lse_ptr, + void* o_ptr, + const void* seqstart_q_ptr, + ck_tile::index_t hdim_q, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head_q, + ck_tile::index_t nhead_ratio_qk, + int32_t num_total_pages, + const void* kv_indptr, + const void* kv_page_indices, +#if 0 // we assume page_block_size=1 for now + const void* kv_last_page_lens, + ck_tile::index_t page_block_size, +#endif + float scale_s, + float scale_p, + float scale_o, + float logits_soft_cap, + ck_tile::index_t stride_q, + ck_tile::index_t stride_k, + ck_tile::index_t stride_v, + ck_tile::index_t stride_bias, + ck_tile::index_t stride_randval, + ck_tile::index_t stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_randval, + ck_tile::index_t nhead_stride_lse, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t window_size_left, + ck_tile::index_t window_size_right, + ck_tile::index_t mask_type, + float p_drop, + bool s_randval, + std::variant, std::pair> + drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + -1, // + hdim_q, + hdim_v, + num_head_q, + nhead_ratio_qk, + num_total_pages, + reinterpret_cast(kv_indptr), + reinterpret_cast(kv_page_indices), +#if 0 // we assume page_block_size=1 for now + reinterpret_cast(kv_last_page_lens), + page_block_size, +#endif +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static_cast(scale_s * ck_tile::log2e_v<>), +#else + scale_s, +#endif + stride_q, + stride_k, + stride_v, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for lse + {}, // placeholder for fp8_static_quant args + {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap + reinterpret_cast(seqstart_q_ptr), + batch_stride_k, + batch_stride_v}; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size_left = window_size_left; + kargs.window_size_right = window_size_right; + kargs.mask_type = static_cast(mask_type); + } + if constexpr(kStoreLSE) + { + kargs.lse_ptr = lse_ptr; + kargs.nhead_stride_lse = nhead_stride_lse; + } + if constexpr(kDoFp8StaticQuant) + { + kargs.scale_p = scale_p; + kargs.scale_o = scale_o; + } + if constexpr(kHasDropout) + { + if(drop_seed_offset.index() == 0) // seed & offset come from host + { + const auto& [seed, offset] = std::get<0>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + else // seed & offset come from device + { + const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, + reinterpret_cast(seed_ptr), + reinterpret_cast(offset_ptr)); + } + + kargs.rand_val_ptr = rand_val_ptr; + kargs.stride_randval = stride_randval; + kargs.nhead_stride_randval = nhead_stride_randval; + kargs.is_store_randval = s_randval; + } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } + + return kargs; + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_q_, + ck_tile::index_t hdim_v_) + { + if constexpr(kIsGroupMode) + { + // TODO: this may need tuning + return dim3(nhead_, + batch_size_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1)); + } + else + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + nhead_, + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + if constexpr(kIsGroupMode) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.z; + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + else + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); + + long_index_t batch_offset_q = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_randval = 0; + long_index_t batch_offset_lse = 0; + long_index_t batch_offset_o = 0; + + const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch]; +#if 0 // we assume page_block_size=1 for now + const int32_t last_page_len = kargs.kv_last_page_lens[i_batch]; +#endif + if constexpr(kIsGroupMode) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; + + batch_offset_q = query_start * kargs.stride_q; + + kargs.kv_page_indices += kargs.kv_indptr[i_batch]; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = query_start * kargs.stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = query_start; + } + if constexpr(kHasDropout) + { + batch_offset_randval = query_start * kargs.stride_randval; + } + batch_offset_o = query_start * kargs.stride_o; + + // get real # queries & # keys under group mode + kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen_q <= i_m0) + { + return; + } + +#if 0 // we assume page_block_size=1 for now + kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; +#else + kargs.seqlen_k = num_page_blocks; +#endif + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + + kargs.kv_page_indices += kargs.kv_indptr[i_batch]; + + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + if constexpr(kStoreLSE) + { + batch_offset_lse = static_cast(i_batch) * kargs.batch_stride_lse; + } + if constexpr(kHasDropout) + { + batch_offset_randval = + static_cast(i_batch) * kargs.batch_stride_randval; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + +#if 0 // we assume page_block_size=1 for now + kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len; +#else + kargs.seqlen_k = num_page_blocks; +#endif + } + + // for simplicity, batch stride we just modify the pointer + const QDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const KDataType* k_ptr = + reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k; + const VDataType* v_ptr = + reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_q), + make_tuple(kargs.stride_q, 1), + number{}, + number<1>{}); + if constexpr(FmhaPipeline::kQLoadOnce) + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + q_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + k_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = transform_tensor_view( + v_dram_naive, + make_tuple( + make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + return pad_tensor_view( + v_dram_transposed, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size), + make_tuple(kargs.stride_v, 1), + number{}, + number<1>{}); + + constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false; + return pad_tensor_view( + v_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(FmhaPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, make_tuple(number{}, number{}), {0, 0}); + + auto v_dram_window = + make_tile_window(v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + // lse + auto lse_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto lse_dram_window_lengths = make_tuple(number{}); + if constexpr(kStoreLSE) + { + LSEDataType* lse_ptr = + reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse; + + const auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view( + lse_ptr, + make_tuple(kargs.seqlen_q), + make_tuple(1), + number<1>{}, + number<1>{}); + + return pad_tensor_view( + lse_dram_naive, lse_dram_window_lengths, sequence{}); + }(); + + return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0}); + } + else + { + return make_null_tile_window(lse_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head_q, + kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val + : *kargs.drop_seed.ptr, + kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val + : *kargs.drop_offset.ptr, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + kargs.is_store_randval}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto randval_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto randval_dram_window_lengths = + make_tuple(number{}, number{}); + if constexpr(kHasDropout) + { + RandValOutputDataType* rand_val_ptr = + reinterpret_cast(kargs.rand_val_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_randval + + batch_offset_randval; + + const auto randval_dram = [&]() { + const auto randval_dram_naive = + make_naive_tensor_view( + rand_val_ptr, + make_tuple(kargs.seqlen_q, kargs.seqlen_k), + make_tuple(kargs.stride_randval, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(randval_dram_naive, + randval_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(randval_dram_window_lengths); + } + }(); + + FmhaMask mask = [&]() { + if constexpr(kHasMask) + return ck_tile::make_generic_attention_mask_from_lr_window( + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT); + else + return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; + }(); + + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + + auto o_acc_tile = [&]() { + if constexpr(kDoFp8StaticQuant) + { + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + kargs.kv_page_indices, + kargs.stride_k, + kargs.stride_v, + dropout); + } + else + { + return FmhaPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + randval_dram_window, + lse_dram_window, + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + kargs.kv_page_indices, + kargs.stride_k, + kargs.stride_v, + dropout); + } + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen_q, kargs.hdim_v), + make_tuple(kargs.stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view( + o_dram_naive, + make_tuple(number{}, number{}), + sequence{}); + }(); + + auto o_dram_window = + make_tile_window(o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 1202524950..a4b3765455 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" #include #include @@ -47,11 +48,13 @@ struct FmhaFwdKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; - using FmhaMask = ck_tile::remove_cvref_t; + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy; @@ -94,7 +97,7 @@ struct FmhaFwdKernel "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ); #undef _SS_ #undef _TS_ @@ -139,6 +142,28 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o; }; + struct FmhaFwdLogitsSoftCapKargs + { + FmhaFwdLogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + struct FmhaFwdCommonBiasKargs { const void* bias_ptr = nullptr; @@ -242,7 +267,8 @@ struct FmhaFwdKernel std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -260,7 +286,8 @@ struct FmhaFwdKernel std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -269,6 +296,13 @@ struct FmhaFwdKernel using Kargs = std::conditional_t; + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + template CK_TILE_HOST static constexpr std::enable_if_t MakeKargsImpl(const void* q_ptr, @@ -287,6 +321,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -343,6 +378,7 @@ struct FmhaFwdKernel {}, // placeholder for lse {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap batch_stride_q, batch_stride_k, batch_stride_v, @@ -398,6 +434,10 @@ struct FmhaFwdKernel kargs.batch_stride_randval = batch_stride_randval; kargs.is_store_randval = s_randval; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -421,6 +461,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -465,6 +506,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -512,6 +554,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -556,6 +599,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -603,6 +647,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -652,6 +697,7 @@ struct FmhaFwdKernel {}, // placeholder for lse {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout + {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; @@ -703,6 +749,10 @@ struct FmhaFwdKernel kargs.nhead_stride_randval = nhead_stride_randval; kargs.is_store_randval = s_randval; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -727,6 +777,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -765,6 +816,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -806,6 +858,7 @@ struct FmhaFwdKernel float scale_s, float scale_p, float scale_o, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -844,6 +897,7 @@ struct FmhaFwdKernel scale_s, scale_p, scale_o, + logits_soft_cap, stride_q, stride_k, stride_v, @@ -1307,6 +1361,21 @@ struct FmhaFwdKernel } }(); + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk}; + auto o_acc_tile = [&]() { if constexpr(kDoFp8StaticQuant) { @@ -1328,6 +1397,9 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } @@ -1342,6 +1414,9 @@ struct FmhaFwdKernel mask, position_encoding, kargs.scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 143abe8048..63011d2ba9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -6,6 +6,8 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/block/variants.hpp" + #include #include @@ -43,14 +45,15 @@ struct FmhaFwdSplitKVKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap; static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV; static constexpr bool kMergeNumHeadGroupsSeqLenQ = FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ; - - using FmhaMask = ck_tile::remove_cvref_t; + using AttentionVariant = ck_tile::remove_cvref_t; + using FmhaMask = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static_assert(!kMergeNumHeadGroupsSeqLenQ || @@ -95,8 +98,8 @@ struct FmhaFwdSplitKVKernel "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) + - (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + + (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" ); #undef _SS_ #undef _TS_ @@ -150,6 +153,28 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t split_stride_o_acc; }; + struct LogitsSoftCapKargs + { + LogitsSoftCapKargs() = default; + + void init_logits_soft_cap(float logits_soft_cap_) + { + if(0 < logits_soft_cap_) + { + logits_soft_cap = logits_soft_cap_; + logits_soft_cap_rcp = 1.f / logits_soft_cap; + } + else + { + logits_soft_cap = 0.f; + logits_soft_cap_rcp = 0.f; + } + } + + float logits_soft_cap; + float logits_soft_cap_rcp; + }; + struct CommonBiasKargs { const void* bias_ptr = nullptr; @@ -207,7 +232,8 @@ struct FmhaFwdSplitKVKernel EmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t + std::conditional_t, + std::conditional_t> { const int32_t* seqlen_k_ptr; @@ -229,7 +255,8 @@ struct FmhaFwdSplitKVKernel EmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { const int32_t* seqstart_q_ptr; const int32_t* seqstart_k_ptr; @@ -243,6 +270,13 @@ struct FmhaFwdSplitKVKernel using Kargs = std::conditional_t; + struct BlockIndices + { + ck_tile::index_t batch_idx; + ck_tile::index_t qo_head_idx; + ck_tile::index_t kv_head_idx; + }; + template __host__ static constexpr std::enable_if_t MakeKargs(const void* q_ptr, @@ -268,6 +302,7 @@ struct FmhaFwdSplitKVKernel const void* cache_batch_idx, float scale_s, float scale_p, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -324,6 +359,7 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for mask {}, // placeholder for fp8_static_quant args {}, // placeholder for paged-block table or cache_batch_idx + {}, // placeholder for logits_soft_cap reinterpret_cast(seqlen_k_ptr), batch_stride_q, batch_stride_k, @@ -363,6 +399,10 @@ struct FmhaFwdSplitKVKernel { kargs.cache_batch_idx = reinterpret_cast(cache_batch_idx); } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -392,6 +432,7 @@ struct FmhaFwdSplitKVKernel bool is_gappy, float scale_s, float scale_p, + float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, @@ -444,6 +485,7 @@ struct FmhaFwdSplitKVKernel {}, // placeholder for mask {}, // placeholder for fp8_static_quant args {}, // placeholder for paged-block table + {}, // placeholder for logits_soft_cap reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr), @@ -478,6 +520,10 @@ struct FmhaFwdSplitKVKernel kargs.page_block_size = page_block_size; kargs.is_gappy = is_gappy; } + if constexpr(kHasLogitsSoftCap) + { + kargs.init_logits_soft_cap(logits_soft_cap); + } return kargs; } @@ -563,7 +609,7 @@ struct FmhaFwdSplitKVKernel } if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { - batch_offset_bias = query_start * kargs.stride_bias + key_start; + batch_offset_bias = query_start * kargs.stride_bias; } batch_offset_lse_acc = query_start; @@ -968,6 +1014,21 @@ struct FmhaFwdSplitKVKernel } }(); + AttentionVariant variant; + const auto variant_params = [&] { + if constexpr(kHasLogitsSoftCap) + { + return ck_tile::LogitsSoftCapParams{ + mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp}; + } + else + { + return ck_tile::StandardAttentionParams{mask, kargs.scale_s}; + } + }(); + + BlockIndices block_indices{i_batch, i_nhead, i_nhead_k}; + auto o_acc_tile = [&, i_split_ = i_split]() { if constexpr(kDoFp8StaticQuant) { @@ -991,6 +1052,9 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, + variant, + variant_params, + block_indices, kv_l2p_offset, smem_ptr); } @@ -1008,6 +1072,9 @@ struct FmhaFwdSplitKVKernel mask, position_encoding, kargs.scale_s, + variant, + variant_params, + block_indices, kv_l2p_offset, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp new file mode 100644 index 0000000000..e07cf1c94e --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -0,0 +1,900 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" + +namespace ck_tile { + +// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future) +template +struct BlockFmhaBatchPrefillPipelineQRKSVSAsync +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QDataType = remove_cvref_t; + using KDataType = remove_cvref_t; + using VDataType = remove_cvref_t; + using SaccDataType = remove_cvref_t; + using SMPLComputeDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using RandValOutputDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using OaccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; + using FmhaMask = remove_cvref_t; + + using BlockFmhaShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kN1 = BlockFmhaShape::kN1; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x) + // only need special care about seq_k padding (oob need set -INF of p instead of zero) + static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && + Problem::kPadHeadDimV == true); + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + static constexpr index_t kAlignmentO = Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 1 : Policy::template GetAlignmentBias(); + +#if CK_TILE_FMHA_FWD_FAST_EXP2 + static constexpr auto R_LOG2E = 1.0 / log2e_v; +#endif + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout) + { + return 1; + } + + if constexpr(kQKHeaddim <= 32) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 64) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 2; + else + return 3; + } + else if constexpr(kQKHeaddim <= 128) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 192) + { + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim <= 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_async"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const KElementFunction& /*k_element_func*/, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, + LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile + const LSEElementFunction& lse_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + const index_t* page_idx, + const index_t stride_k, + const index_t stride_v, + DropoutType& dropout) const + { + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto LdsSeq = Policy::template GetLdsBufferSequence(); + + // K tile in LDS + auto k_lds_ptr = reinterpret_cast(smem_ptr); + auto k_lds_store = generate_tuple( + [&](auto i_buf) { + return make_tile_window( + make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor(i_buf)), + Policy::template MakeKLdsStoreBlockDescriptor(i_buf).get_lengths(), + {0, 0, 0}); + }, + number{}); + + auto k_lds_Load_view = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor()); + + auto k_lds_load = + make_tile_window(k_lds_Load_view, + Policy::template MakeKLdsLoadBlockDescriptor().get_lengths(), + {0, 0}); + + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(smem_ptr), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + q_dram_window.init_raw(); + + // TODO: we use async Copy for K, which is inline asm + // a side effect is we have to use inline asm for q as well + auto q = decltype(load_tile(q_dram_window)){}; + // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q. + // however, q would be cleared in the constructor of static distributed tensor + // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch + load_tile_raw(q, q_dram_window); + __builtin_amdgcn_sched_barrier(0); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_max = [](auto e0, auto e1) { return max(e0, e1); }; + const auto f_sum = [](auto e0, auto e1) { return e0 + e1; }; + + // infer Sacc, S, P, M, L, Oacc type + using SBlockTileType = decltype(cast_tile(s_acc)); + + using MLBlockTileType = decltype(block_tile_reduce( + SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0})); + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + auto m = MLBlockTileType{}; + auto l = MLBlockTileType{}; + + clear_tile(o_acc); + set_tile(m, -numeric::infinity()); + clear_tile(l); + + __builtin_amdgcn_sched_barrier(0); + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if no work to do + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) + { + if(num_total_loop <= 0) + { + if constexpr(kStoreLSE) + { + auto lse = + make_static_distributed_tensor(m.get_tile_distribution()); + + set_tile(lse, -numeric::infinity()); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) + // otherwise will have compute error(maybe compiler bug?) + + // Note: here occ are all cleard, return it + return o_acc; + } + __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check + } + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dist = Policy::template MakeKDramTileDistribution(); + auto k_coord = k_dist.calculate_index(); + using KDstrEncode = typename decltype(k_dist)::DstrEncode; + constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; + statically_indexed_array k_offsets; + static_for<0, NRepeat, 1>{}([&](auto n0) { + k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; + }); + auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + k_dist, + k_offsets); // K DRAM tile window for + k_dram_window.init_raw(); + constexpr auto k_oob_ck = bool_constant{}; + constexpr auto k_pre_np = [&]() { + if constexpr(kPadSeqLenK && + (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout))) + return bool_constant{}; + else + return bool_constant{}; + }(); + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto randval_dram_window = dropout.template MakeRandvalDramWindow( + randval_dram_block_window_tmp, seqlen_k_start); + + auto v_dist = Policy::template MakeVDramTileDistribution(); + auto v_coord = v_dist.calculate_index(); + const auto VPageIndexDim = I1; + using VDstrEncode = typename decltype(v_dist)::DstrEncode; + constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; + statically_indexed_array v_offsets; + (void)stride_k; + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v; + }); + + auto v_dram_window = + make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + v_dist, + v_offsets, + VPageIndexDim); + + // prefetch K tile + async_load_tile_raw( + k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + __builtin_amdgcn_sched_barrier(0); + + buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer()); + (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 + // auto q_tile = q; // tile_elementwise_in(q_element_func, q); + + index_t i_total_loops = 0; + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + + static_assert(1 <= k0_loops); + static_assert(1 <= k1_loops); + // main loop + do + { + // STAGE 1, QK gemm + clear_tile(s_acc); // initialize C + if constexpr(k0_loops > 1) + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + async_load_tile_raw(k_lds_store(number{})>{}), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + if constexpr(i_k0 < k0_loops - 1) + move_tile_window(k_dram_window, {0, kK0}); + + async_load_fence(k_dram_window.get_num_of_access()); + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + gemm_0(s_acc, + get_slice_tile( + q, sequence<0, i_k0 * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + }); + } + + // TODO: this to fix a bug when loop smaller than 2, + // the following fence/barrier will be scheduled inside 1st loop + if constexpr(k0_loops <= 2) + __builtin_amdgcn_sched_barrier(0); + + async_load_fence(); + __builtin_amdgcn_s_barrier(); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant{}); + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v; + }); + v_dram_window.update_page_idx(v_offsets); + + __builtin_amdgcn_sched_barrier(0); + { // tail + gemm_0( + s_acc, + get_slice_tile( + q, sequence<0, (k0_loops - 1) * kK0>{}, sequence{}), + get_slice_tile(k_lds_load, + sequence<(LdsSeq.at(number{})) * kN0, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); + } + __builtin_amdgcn_sched_barrier(1); + + // STAGE 2, scale_s, add bias, mask, softmax + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + x += type_convert(bias_element_func(y)); +#else + x += log2e_v * + type_convert(bias_element_func(y)); +#endif + }, + s_acc, + bias_tile); + } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#else + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } + } + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(kPadSeqLenK || FmhaMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), + k_origin.at(number<0>{}), + number{}, + number{}); + + if(need_perpixel_check) + { + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }); + } + } + + const auto s = cast_tile(s_acc); // S{j} + auto m_local = block_tile_reduce( + s, + sequence<1>{}, + f_max, + -numeric::infinity()); // m_local = rowmax(S{j}) + block_tile_reduce_sync(m_local, f_max, bool_constant{}); + + const auto m_old = m; // m{j-1} + tile_elementwise_inout( + [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + + auto p_compute = make_static_distributed_tensor( + s.get_tile_distribution()); // Pcompute{j} + + __builtin_amdgcn_sched_barrier(0x7F); + // store & prefetch next v, after the max reduction + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + + store_tile( + v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = + get_slice_tile(v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store the prefetch + } + + if constexpr(k1_loops > 1) + { + move_tile_window( + v_dram_window, + {0, kK1}); // will have scratch if move this right after load_tile(v_dram)... + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = + page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v; + }); + v_dram_window.update_page_idx(v_offsets); + } + __builtin_amdgcn_sched_barrier(0); + + static const auto get_validated_m = [](SMPLComputeDataType raw_m) { + /// NOTICE: bias might be materialized mask including -inf values, need + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) + { + return raw_m == -numeric::infinity() + ? type_convert(0.f) + : raw_m; + } + else + { + return raw_m; + } + }; + + constexpr auto p_spans = decltype(p_compute)::get_distributed_spans(); + sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + auto row_max = scale_s * get_validated_m(m[i_idx]); +#endif + sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } + } +#else + p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); +#endif + }); + }); + + auto rowsum_p = block_tile_reduce( + p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) + + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + // l{j}, Oacc{j} + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + const auto tmp = [&]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } + } + }(); +#else + const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx])); +#endif + l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx]; + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + // FIXME: this use different equation from FA v2 paper, + // but produce correc result. + // Is the equation wrong? + o_acc(i_j_idx) *= tmp; + }); + }); + + if constexpr(kHasDropout) + { + auto randval_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV(); + dropout.template Run( + randval_ptr, + seqlen_k_start + i_total_loops * kN0, + p_compute, + randval_dram_window); + } + + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, p_compute)); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); + }(); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1) + { + v_buf = load_tile( + v_dram_window, number<-1>{}, bool_constant{}); // load next v_buf + static_for<0, V_KRepeat, 1>{}([&](auto k0) { + v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 + + v_coord[VPageIndexDim] + k0.value] * + stride_v; + }); + v_dram_window.update_page_idx(v_offsets); + } + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_buf); + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, + v_shuffle_tmp)); // store the prefetch + } + else + { + auto v_lds_window_tmp = get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{}); + store_tile(v_lds_window_tmp, + tile_elementwise_in(v_element_func, v_buf)); // store next v_buf + } + if constexpr(i_k1 < k1_loops - 1) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + i_total_loops++; + if(i_total_loops < num_total_loop) + { + page_idx += kN0; + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + k_dram_window.set_window_origin(k_dram_block_window.get_window_origin()); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k; + }); + k_dram_window.update_page_idx(k_offsets); + if constexpr(k1_loops >= 2 && + LdsSeq.at(number<0>{}) == LdsSeq.at(number{})) + __builtin_amdgcn_s_barrier(); + async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), + k_dram_window, + number<-1>{}, + k_oob_ck, + k_pre_np); + move_tile_window(k_dram_window, {0, kK0}); + } + // tail + { + block_sync_lds(); + gemm_1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + get_slice_tile( + v_lds_window, + sequence<(LdsSeq.at(number{})) * kN1, 0>{}, + sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); + } + } while(i_total_loops < num_total_loop); + + // store lse + if constexpr(kStoreLSE) + { + auto lse = make_static_distributed_tensor(m.get_tile_distribution()); + + constexpr auto lse_spans = decltype(lse)::get_distributed_spans(); + sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } + } +#else + lse(i_idx) = m_[i_idx] + log(l_[i_idx]); +#endif + }); + + store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); + } + + // finally, O + constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); + + sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + const auto tmp = [&]() { + if constexpr(FmhaMask::IsMasking) + { + return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx]; + } + else + return 1 / l[i_idx]; + }(); + sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + o_acc(i_j_idx) *= tmp; + }); + }); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + PositionEncoding position_encoding, + float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, + void* smem_ptr, + const index_t* page_idx, + const index_t stride_k, + const index_t stride_v, + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + randval_dram_block_window_tmp, + lse_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + position_encoding, + scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + page_idx, + stride_k, + stride_v, + dropout); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp new file mode 100644 index 0000000000..02731ca8f8 --- /dev/null +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +// This pipeline is qkv all located in LDS +using BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy = + BlockFmhaPipelineQXKSVSCustomPolicy; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp index 809c58f1d1..4d1c38e079 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp @@ -27,6 +27,7 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -46,15 +47,21 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kIsPagedKV = Problem::kIsPagedKV; - static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -128,7 +135,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -150,6 +159,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -453,9 +465,34 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#else + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, kN0}); @@ -574,7 +611,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx])); @@ -603,8 +647,15 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -711,7 +762,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS } else { - lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } } #else lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); @@ -757,7 +815,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS typename VPageBlockNavigator, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile @@ -771,6 +831,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -794,6 +857,9 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, kv_l2p_offset, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index ce80dba5eb..7f5f79d7a7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -26,6 +26,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -45,15 +46,21 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kIsPagedKV = Problem::kIsPagedKV; - static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kIsPagedKV = Problem::kIsPagedKV; + static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -127,7 +134,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -149,6 +158,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -401,9 +413,28 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout(apply_logits_transform, s_acc); +#else + tile_elementwise_inout(apply_logits_transform, s_acc); #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, kN0}); @@ -497,7 +528,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -522,8 +560,16 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -620,7 +666,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } else { - lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } } #else lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]); @@ -662,7 +715,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS typename VPageBlockNavigator, typename BiasDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile @@ -676,6 +731,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate void* smem_ptr) const { @@ -699,6 +757,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, kv_l2p_offset, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 9a5208c025..f35c00c268 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -20,6 +20,7 @@ template struct BlockFmhaPipelineProblem @@ -36,6 +37,7 @@ struct BlockFmhaPipelineProblem using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; @@ -50,6 +52,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kHasDropout = Traits::kHasDropout; @@ -69,6 +72,7 @@ template struct BlockFmhaFwdSplitKVPipelineProblem @@ -84,6 +88,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using Traits = remove_cvref_t; @@ -98,6 +103,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap; static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 8a4a925b81..29f183c613 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -5,8 +5,8 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -28,6 +28,7 @@ struct BlockFmhaPipelineQRKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -47,14 +48,20 @@ struct BlockFmhaPipelineQRKSVS static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -101,7 +108,7 @@ struct BlockFmhaPipelineQRKSVS else { return 1; - }; + } } }(); @@ -128,7 +135,9 @@ struct BlockFmhaPipelineQRKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -147,6 +156,9 @@ struct BlockFmhaPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -380,9 +392,28 @@ struct BlockFmhaPipelineQRKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout(apply_logits_transform, s_acc); +#else + tile_elementwise_inout(apply_logits_transform, s_acc); #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) @@ -398,7 +429,12 @@ struct BlockFmhaPipelineQRKSVS s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); }); } } @@ -450,7 +486,14 @@ struct BlockFmhaPipelineQRKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -475,8 +518,16 @@ struct BlockFmhaPipelineQRKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -574,7 +625,14 @@ struct BlockFmhaPipelineQRKSVS } else { - lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } } #else lse(i_idx) = m_[i_idx] + log(l_[i_idx]); @@ -614,7 +672,9 @@ struct BlockFmhaPipelineQRKSVS typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -625,6 +685,9 @@ struct BlockFmhaPipelineQRKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -645,6 +708,9 @@ struct BlockFmhaPipelineQRKSVS mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 67354fc72d..7af3902dc5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVSAsync using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -53,13 +54,19 @@ struct BlockFmhaPipelineQRKSVSAsync // only need special care about seq_k padding (oob need set -INF of p instead of zero) static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true && Problem::kPadHeadDimV == true); - static constexpr bool kPadSeqLenQ = true; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) - static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kPadSeqLenQ = true; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) + static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this @@ -153,7 +160,9 @@ struct BlockFmhaPipelineQRKSVSAsync typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -172,6 +181,9 @@ struct BlockFmhaPipelineQRKSVSAsync FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -435,9 +447,34 @@ struct BlockFmhaPipelineQRKSVSAsync else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } +#else + for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i) + { + apply_logits_transform(s_acc.thread_buf_[i]); + } #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) @@ -454,7 +491,12 @@ struct BlockFmhaPipelineQRKSVSAsync s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); }); } } @@ -543,7 +585,14 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -568,8 +617,15 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -695,7 +751,14 @@ struct BlockFmhaPipelineQRKSVSAsync } else { - lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]); + } } #else lse(i_idx) = m_[i_idx] + log(l_[i_idx]); @@ -735,7 +798,9 @@ struct BlockFmhaPipelineQRKSVSAsync typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -746,6 +811,9 @@ struct BlockFmhaPipelineQRKSVSAsync FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -766,6 +834,9 @@ struct BlockFmhaPipelineQRKSVSAsync mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 7be6a347f5..4efcd871dc 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce.hpp" namespace ck_tile { @@ -27,6 +28,7 @@ struct BlockFmhaPipelineQSKSVS using PDataType = remove_cvref_t; using OaccDataType = remove_cvref_t; using ODataType = remove_cvref_t; + using AttentionVariant = remove_cvref_t; using FmhaMask = remove_cvref_t; using BlockFmhaShape = remove_cvref_t; @@ -44,14 +46,21 @@ struct BlockFmhaPipelineQSKSVS static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim; - static constexpr bool kIsGroupMode = Problem::kIsGroupMode; - static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr auto BiasEnum = Problem::BiasEnum; - static constexpr bool kStoreLSE = Problem::kStoreLSE; - static constexpr bool kHasDropout = Problem::kHasDropout; + static constexpr bool kIsGroupMode = Problem::kIsGroupMode; + static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; + static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap; + static constexpr auto BiasEnum = Problem::BiasEnum; + static constexpr bool kStoreLSE = Problem::kStoreLSE; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 && + (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS || + !kHasLogitsSoftCap)) || + (!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap)); + // last dimension vector length used to create tensor view(and decide buffer_load vector length) // ... together with tensor distribution. tensor dist should able to overwrite this static constexpr index_t kAlignmentQ = @@ -95,7 +104,9 @@ struct BlockFmhaPipelineQSKSVS return 1; } else + { return 1; + } } }(); @@ -122,7 +133,9 @@ struct BlockFmhaPipelineQSKSVS typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -141,6 +154,9 @@ struct BlockFmhaPipelineQSKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& /* unused_dropout */) const { @@ -380,9 +396,28 @@ struct BlockFmhaPipelineQSKSVS else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + if constexpr(kHasLogitsSoftCap) + { + auto apply_logits_transform = + [&variant, &variant_params, &block_indices](auto& x) { + x = variant.LogitsTransform(variant_params, + variant.QueryTransform(variant_params, x), + block_indices.batch_idx, + block_indices.qo_head_idx, + block_indices.kv_head_idx); + }; #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout(apply_logits_transform, s_acc); +#else + tile_elementwise_inout(apply_logits_transform, s_acc); #endif + } + else + { +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); +#endif + } } move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) @@ -398,7 +433,12 @@ struct BlockFmhaPipelineQSKSVS s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); + return !variant.LogitsMask(variant_params, + block_indices.batch_idx, + row, + col, + block_indices.qo_head_idx, + block_indices.kv_head_idx); }); } } @@ -450,7 +490,14 @@ struct BlockFmhaPipelineQSKSVS } else { - p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); + } + else + { + p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); + } } #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); @@ -481,8 +528,16 @@ struct BlockFmhaPipelineQSKSVS } else { - auto row_max = scale_s * get_validated_m(m[i_idx]); - return exp2(scale_s * m_old[i_idx] - row_max); + if constexpr(kHasLogitsSoftCap) + { + + return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); + } + else + { + auto row_max = scale_s * get_validated_m(m[i_idx]); + return exp2(scale_s * m_old[i_idx] - row_max); + } } }(); #else @@ -571,7 +626,14 @@ struct BlockFmhaPipelineQSKSVS } else { - lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + if constexpr(kHasLogitsSoftCap) + { + lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); + } + else + { + lse(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]); + } } #else lse(i_idx) = m_[i_idx] + log(l_[i_idx]); @@ -611,7 +673,9 @@ struct BlockFmhaPipelineQSKSVS typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, - typename PositionEncoding> + typename PositionEncoding, + typename AttentionVariantParams, + typename BlockIndices> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -622,6 +686,9 @@ struct BlockFmhaPipelineQSKSVS FmhaMask mask, PositionEncoding position_encoding, float scale_s, + const AttentionVariant& variant, + const AttentionVariantParams& variant_params, + const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout) const { @@ -642,6 +709,9 @@ struct BlockFmhaPipelineQSKSVS mask, position_encoding, scale_s, + variant, + variant_params, + block_indices, smem_ptr, dropout); } diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 8d2d848558..4530b58d85 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -13,6 +13,7 @@ template 1 or fwd training is running */ @@ -51,6 +54,7 @@ struct TileFmhaFwdSplitKVTraits static constexpr bool kPadSeqLenK = kPadSeqLenK_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_; static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kStoreLSE = kStoreLSE_; diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 6a7ccd2472..664294fe18 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -19,6 +19,10 @@ namespace ck_tile { #define MOE_SORTING_USE_EX_KERNEL 1 #endif +#ifndef MOE_SORTING_FUSE_MP_01 +#define MOE_SORTING_FUSE_MP_01 0 +#endif + // clang-format off // [indexing implementation-1] // using M_a as constexpr block_size to partition all tokens into different slices @@ -118,7 +122,7 @@ CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_ex int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here int smem_rows = [&](){ index_t target_occupancy_ = 2; - constexpr index_t total_ = 65536 / sizeof(int); + constexpr index_t total_ = get_smem_capacity() / sizeof(index_t); constexpr index_t sub_unroll = 8; constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt // at lease 2 lines, one for sub_token unroll, one for cumsum @@ -250,7 +254,7 @@ struct MoeSortingKernel { #if MOE_SORTING_USE_EX_KERNEL auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts); - return smem_rows * smem_cols * sizeof(int); + return smem_rows * smem_cols * sizeof(index_t); #else const auto blocks = BlockSize(h); // usually num_experts is power of 2, we pad 1 dword here for the row-size @@ -1063,17 +1067,43 @@ CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens) return (tokens + chunk - 1) / chunk * chunk; }; -CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_elem(index_t tokens, index_t num_experts) +// 4-i32 mesh, 2-i16 mseh, 1-i8 mesh +CK_TILE_HOST index_t moe_sorting_mesh_byte_size(index_t tokens_, + index_t /*num_experts_*/, + index_t topk_) +{ + // small token case, let's run mesh with dword score board + if(tokens_ < 512) + return 4; + else + { + if(topk_ >= 255) + return 2; // 16bit mesh + else + return 1; // 8bit mesh if small enough + } +} + +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_smem_size(index_t tokens, + index_t num_experts, + index_t topk) { index_t row_size = moe_sorting_mp_mesh_stride(tokens); - return num_experts * row_size; + index_t elem = num_experts * row_size; + return elem * moe_sorting_mesh_byte_size(tokens, num_experts, topk); }; -CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_elem(index_t num_experts) +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_smem_size(index_t num_experts) { constexpr index_t chunk = 32; index_t row_size = num_experts + 1; - return (row_size + chunk - 1) / chunk * chunk; + return (row_size + chunk - 1) / chunk * chunk * sizeof(index_t); +}; + +CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size() +{ + constexpr index_t chunk = 32; + return chunk * sizeof(index_t); }; template @@ -1245,15 +1275,20 @@ CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_) } // return size in byte -CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_) +CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_, int topk_) { - index_t elem = impl::moe_sorting_mp_mesh_elem(tokens_, num_experts_) + - impl::moe_sorting_mp_cumsum_elem(num_experts_); - return elem * sizeof(index_t); + index_t s_ = impl::moe_sorting_mp_mesh_smem_size(tokens_, num_experts_, topk_) + + impl::moe_sorting_mp_cumsum_smem_size(num_experts_) +#if MOE_SORTING_FUSE_MP_01 + + impl::moe_sorting_mp_sem_smem_size(); +#else + ; +#endif + return s_; } // return size in byte -CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_) +CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_) { #if 1 if(moe_sorting_is_oneshot(tokens_, num_experts_)) @@ -1262,10 +1297,10 @@ CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts } else { - return moe_sorting_mp_get_workspace_size(tokens_, num_experts_); + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_); } #else - return moe_sorting_mp_get_workspace_size(tokens_, num_experts_); + return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_); #endif } @@ -1320,6 +1355,7 @@ struct MoeSortingMultiPhaseKernel_P0 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1371,22 +1407,21 @@ struct MoeSortingMultiPhaseKernel_P0 { using topk_id_t = ext_vector_t; - static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || - Problem::SubTokenTile == 4); - const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); - IndexType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; #pragma unroll Problem::SubTokenTile - for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; i += blockDim.x) + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; + i += gridDim.x * BLOCK_SIZE) { auto x = p_topk_ids[i]; static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { IndexType eid = x[j.value]; // ext_vector_type must use int to [] uint32_t curr_token_id, curr_topk_id; kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); - p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1; + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = + (curr_topk_id + 1) & 0xffff; }); } } @@ -1400,6 +1435,7 @@ struct MoeSortingMultiPhaseKernel_P1 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1420,9 +1456,9 @@ struct MoeSortingMultiPhaseKernel_P1 Kargs k; k.p_local_expert_mask = h.p_local_expert_mask; k.p_expert_mesh = h.p_ws; - k.p_expert_cumsum = - reinterpret_cast(reinterpret_cast(h.p_ws) + - impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); return k; @@ -1444,13 +1480,11 @@ struct MoeSortingMultiPhaseKernel_P1 int eid = blockIdx.x; - constexpr index_t index_pack = 4; // always packed - using r_t = ext_vector_t; // always use int32x4 + constexpr index_t index_pack = Problem::SubTokenTile; // always packed + using r_t = ext_vector_t; // always use int32x4 r_t* p_expert_mesh = reinterpret_cast( - reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); + reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); - static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || - Problem::SubTokenTile == 4); const IndexType* p_local_expert_mask = static_cast(kargs.p_local_expert_mask); IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); @@ -1502,6 +1536,197 @@ struct MoeSortingMultiPhaseKernel_P1 } }; +#if MOE_SORTING_FUSE_MP_01 +template +struct MoeSortingMultiPhaseKernel_P01 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_topk_ids; // [tokens, topk] + const void* p_local_expert_mask; // [expert] + void* p_expert_mesh; // [expert, tokens] + void* p_expert_cumsum; // [expert + 1] + void* p_expert_sem; // [1] + index_t tokens; + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + index_t wg_count; // used for semaphore + mdiv topk_mdiv; + }; + + CK_TILE_HOST static constexpr auto get_num_cu() + { + index_t num_cu = [&]() { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + HIP_CHECK_ERROR(hipGetDevice(&dev)); + HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; + }(); + return num_cu; + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_topk_ids = h.p_topk_ids; + k.p_local_expert_mask = h.p_local_expert_mask; + k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); + k.p_expert_sem = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk) + + impl::moe_sorting_mp_cumsum_smem_size(h.num_experts)); + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.wg_count = WGCounts(h); + k.topk_mdiv = mdiv{static_cast(h.topk)}; + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h) + { + index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile; + index_t elem_cnt = (total_elem + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // no more than grid_size + return min(elem_cnt, GridSize(h)); + } + + // in byte + CK_TILE_HOST static constexpr auto GetSmemSize() + { + return BLOCK_SIZE / warpSize * sizeof(IndexType); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + workgroup_barrier wb{reinterpret_cast(kargs.p_expert_sem)}; + + { + using topk_id_t = ext_vector_t; + + const topk_id_t* p_topk_ids = reinterpret_cast(kargs.p_topk_ids); + IndexType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + index_t total_elem = kargs.tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile; + +#pragma unroll Problem::SubTokenTile + for(index_t i = blockIdx.x * BLOCK_SIZE + threadIdx.x; i < total_elem; + i += BLOCK_SIZE * gridDim.x) + { + auto x = p_topk_ids[i]; + static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) { + IndexType eid = x[j.value]; // ext_vector_type must use int to [] + uint32_t curr_token_id, curr_topk_id; + kargs.topk_mdiv.divmod( + i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id); + p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1; + }); + } + if(static_cast(blockIdx.x) < kargs.wg_count) + { + wb.inc(); + } + } + + { + __shared__ char smem[GetSmemSize()]; + int eid = blockIdx.x; + + // early exist in case of extra atomic wait + if(eid >= kargs.num_experts) + return; + + wb.wait_lt(kargs.wg_count); + + for(; eid < kargs.num_experts; eid += gridDim.x) + { + // if(threadIdx.x == 0) + // printf("!!! bid:%d, eid:%d (%d, %d)\n", + // static_cast(blockIdx.x), + // eid, + // kargs.num_experts, + // static_cast(blockDim.x)); + constexpr index_t index_pack = 4; // always packed + using r_t = ext_vector_t; // always use int32x4 + r_t* p_expert_mesh = reinterpret_cast( + reinterpret_cast(kargs.p_expert_mesh) + eid * kargs.mesh_stride); + + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + + auto f_sum = [](auto x_, auto y_) { return x_ + y_; }; + + int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + + if constexpr(Problem::LocalExpertMasking) + { + IndexType mask = p_local_expert_mask[eid]; + if(mask == 0) + continue; // skip + } + + index_t cnt = 0; // per-wave cnt + for(int i = 0; i < loops; i++) + { + int position = i * BLOCK_SIZE + threadIdx.x; + r_t v{0}; + if(position < (kargs.mesh_stride / index_pack)) + v = p_expert_mesh[position]; + index_t local_sum = 0; + static_for<0, index_pack, 1>{}( + [&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; }); + cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum); + } + + index_t lane_id = threadIdx.x % warpSize; + index_t wave_id = threadIdx.x / warpSize; + + // reduce cross wave + IndexType* s = reinterpret_cast(smem); + __syncthreads(); + if(lane_id == 0) + { + s[wave_id] = cnt; + } + __syncthreads(); + + if(threadIdx.x == 0) + { + index_t c = 0; + for(auto i = 0; i < (BLOCK_SIZE / warpSize); i++) + { + c += s[i]; + } + p_expert_cumsum[eid] = c; + } + } + } + } +}; +#endif + // token count cumsum template struct MoeSortingMultiPhaseKernel_P2 @@ -1510,6 +1735,7 @@ struct MoeSortingMultiPhaseKernel_P2 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1536,10 +1762,9 @@ struct MoeSortingMultiPhaseKernel_P2 { Kargs k; k.p_local_expert_mask = h.p_local_expert_mask; - // k.p_expert_mesh = h.p_ws; - k.p_expert_cumsum = - reinterpret_cast(reinterpret_cast(h.p_ws) + - impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; k.p_sorted_expert_ids = h.p_sorted_expert_ids; @@ -1566,7 +1791,8 @@ struct MoeSortingMultiPhaseKernel_P2 // in byte CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() { - return 2 * BLOCK_SIZE * sizeof(IndexType); + // return 2 * BLOCK_SIZE * sizeof(IndexType); + return (4 + 2 * BLOCK_SIZE / warpSize) * sizeof(IndexType); } // reduce single pixel within a wave @@ -1718,6 +1944,7 @@ struct MoeSortingMultiPhaseKernel_P3 using IndexType = typename Problem::IndexType; using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; static constexpr index_t BLOCK_SIZE = 256; static constexpr index_t OCCUPANCY = 2; // hard coded @@ -1749,9 +1976,9 @@ struct MoeSortingMultiPhaseKernel_P3 k.p_sorted_token_ids = h.p_sorted_token_ids; k.p_sorted_weights = h.p_sorted_weights; k.p_expert_mesh = h.p_ws; - k.p_expert_cumsum = - reinterpret_cast(reinterpret_cast(h.p_ws) + - impl::moe_sorting_mp_mesh_elem(h.tokens, h.num_experts)); + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); k.tokens = h.tokens; k.num_experts = h.num_experts; k.topk_mdiv = mdiv{static_cast(h.topk)}; @@ -1782,9 +2009,6 @@ struct MoeSortingMultiPhaseKernel_P3 const WeightType* p_weights = static_cast(kargs.p_weights); WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); - static_assert(Problem::SubTokenTile == 1 || Problem::SubTokenTile == 2 || - Problem::SubTokenTile == 4); - int eid = blockIdx.x; int wave_id = threadIdx.x / warpSize; int lane_id = threadIdx.x % warpSize; @@ -1866,6 +2090,495 @@ struct MoeSortingMultiPhaseKernel_P3 } }; +namespace impl { +// we use dynamic LDS size here +CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_) +{ + constexpr index_t BLOCK_SIZE = 256; // hardcoded 256 + const index_t expert_cumsum_elem = num_experts_ + 1; + return (4 + 2 * BLOCK_SIZE / warpSize + expert_cumsum_elem) * sizeof(int); +} +} // namespace impl + +// token count cumsum +template +struct MoeSortingMultiPhaseKernel_P23 +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + using MeshType = typename Problem::MeshType; + + static constexpr index_t BLOCK_SIZE = 256; + static constexpr index_t OCCUPANCY = 2; // hard coded + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + struct Kargs + { + const void* p_weights; + const void* p_local_expert_mask; // [expert] + void* p_expert_mesh; // [expert, tokens] + void* p_expert_cumsum; // [expert + 1] + void* p_total_tokens_post_pad; // [1] + void* p_sorted_expert_ids; + + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_moe_buf; + + index_t tokens; + index_t num_experts; + index_t mesh_stride; // mesh_stride for p_expert_mesh + mdiv unit_size_mdiv; + mdiv topk_mdiv; + long_index_t moe_buf_bytes; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_weights = h.p_weights; + k.p_local_expert_mask = h.p_local_expert_mask; + k.p_expert_mesh = h.p_ws; + k.p_expert_cumsum = reinterpret_cast( + reinterpret_cast(h.p_ws) + + impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk)); + k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; + k.p_sorted_expert_ids = h.p_sorted_expert_ids; + + k.p_sorted_token_ids = h.p_sorted_token_ids; + k.p_sorted_weights = h.p_sorted_weights; + + k.p_moe_buf = h.p_moe_buf; + + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens); + k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; + k.topk_mdiv = mdiv{static_cast(h.topk)}; + + k.moe_buf_bytes = h.moe_buf_bytes; + + return k; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + { + // use 1 block to cumsum + // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + return dim3(h.num_experts + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BLOCK_SIZE * 16)); + } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(BLOCK_SIZE); } + + // only use this at host ! + CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) + { + const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts); + const auto smem_sf = BLOCK_SIZE * 4 * sizeof(IndexType); + return max(smem_23, smem_sf); + } + + // reduce single pixel within a wave + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + if(static_cast(blockIdx.x) >= kargs.num_experts) + { + impl::moe_buf_set_zero_kernel( + reinterpret_cast(kargs.p_moe_buf), + kargs.moe_buf_bytes, + blockIdx.x - kargs.num_experts); + return; + } + + extern __shared__ char smem[]; + { + IndexType* s = reinterpret_cast(smem); + + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* p_expert_cumsum = reinterpret_cast(kargs.p_expert_cumsum); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + IndexType* p_total_tokens_post_pad = + reinterpret_cast(kargs.p_total_tokens_post_pad); + IndexType* p_sorted_expert_ids = + reinterpret_cast(kargs.p_sorted_expert_ids); + + const index_t loops = (kargs.num_experts + BLOCK_SIZE - 1) / BLOCK_SIZE; + index_t wave_id = threadIdx.x / warpSize; + index_t lane_id = threadIdx.x % warpSize; + + IndexType prev_cumsum_a = 0; + IndexType prev_cumsum_b = 0; + + for(index_t i = 0; i < loops; i++) + { + index_t position = i * BLOCK_SIZE + threadIdx.x; + IndexType a_ = 0; // token count for a expert + IndexType b_ = 0; // mask for a expert + if(position < kargs.num_experts) + { + a_ = p_expert_cumsum[position]; + if constexpr(Problem::LocalExpertMasking) + b_ = p_local_expert_mask[position]; + } + + int blocks_pers_expert = + kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1); + // pad token + int padded_blocks_per_expert = [&]() { + int x_ = [&]() { + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + // if local_cnt is zero, blocks_pers_expert will be zero + // this is what we want to achieve + return blocks_pers_expert; // * kargs.unit_size_mdiv.divisor; + } + else + { + return max(blocks_pers_expert, 1); + } + }(); + if constexpr(Problem::LocalExpertMasking) + { + return b_ ? x_ : 0; + } + else + return x_; + }(); + + IndexType cumsum_a = padded_blocks_per_expert; + IndexType cumsum_b = b_; + + // Note: we first cumsum local round, then add previous cumsum + impl::moe_sorting_wave_cumsum(cumsum_a); + impl::moe_sorting_wave_cumsum(cumsum_b); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum_a; + s[4 + wave_id + BLOCK_SIZE / warpSize] = cumsum_b; + } + + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev_a = s[4 + i_w]; + IndexType prev_b = s[4 + i_w + BLOCK_SIZE / warpSize]; + prev_a = wave_id > i_w ? prev_a : 0; // mask out + prev_b = wave_id > i_w ? prev_b : 0; // mask out + cumsum_a += prev_a; + cumsum_b += prev_b; + }); + + // Now let's add previous cumsum + cumsum_a += prev_cumsum_a; + cumsum_b += prev_cumsum_b; + + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[2] = cumsum_a; // store the last cumsum + s[3] = cumsum_b; + } + + IndexType out_0 = cumsum_a - padded_blocks_per_expert; // exclusive cumsum tok cnt + IndexType out_1 = cumsum_b - b_; // exclusive cumsum mask cnt + + __syncthreads(); + prev_cumsum_a = s[2]; + prev_cumsum_b = s[3]; + + if(position < kargs.num_experts) + { + p_expert_cumsum_smem[position] = out_0 * kargs.unit_size_mdiv.divisor; + } + + { + if(blockIdx.x == 0) + { + if constexpr(Problem::LocalExpertMasking) + { + if(b_) + { + for(int j = 0; j < blocks_pers_expert; j++) + { + p_sorted_expert_ids[out_0 + j] = out_1; + } + } + } + else + { + for(int j = 0; j < blocks_pers_expert; j++) + { + p_sorted_expert_ids[out_0 + j] = position; + } + } + } + } + } + + if(threadIdx.x == 0) + { + auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor; + if(blockIdx.x == 0) + p_total_tokens_post_pad[0] = total_tokens_post_pad; + p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad; + } + } + + __syncthreads(); + + { + const IndexType* p_local_expert_mask = + static_cast(kargs.p_local_expert_mask); + IndexType* s = reinterpret_cast(smem); + MeshType* p_expert_mesh = reinterpret_cast(kargs.p_expert_mesh); + IndexType* p_sorted_token_ids = reinterpret_cast(kargs.p_sorted_token_ids); + IndexType* p_expert_cumsum_smem = s + 4 + 2 * BLOCK_SIZE / warpSize; + const WeightType* p_weights = static_cast(kargs.p_weights); + WeightType* p_sorted_weights = reinterpret_cast(kargs.p_sorted_weights); + + int eid = blockIdx.x; + int wave_id = threadIdx.x / warpSize; + int lane_id = threadIdx.x % warpSize; + int e_start = p_expert_cumsum_smem[eid]; + int e_end = p_expert_cumsum_smem[eid + 1]; + if constexpr(Problem::SkipExpertsWithZeroTokens) + { + if(e_start == e_end) + return; + } + + if constexpr(Problem::LocalExpertMasking) + { + int e_mask = p_local_expert_mask[eid]; + if(e_mask == 0) + return; // skip empty expert + } + + // cumsum one by one + constexpr index_t index_pack = Problem::SubTokenTile; // always packed + using r_t = ext_vector_t; // always use int32x4 + using d_t = ext_vector_t; + int loops = (kargs.mesh_stride / index_pack + BLOCK_SIZE - 1) / BLOCK_SIZE; + int prev_cumsum = 0; + + for(int i = 0; i < loops; i++) + { + int i_token_pack = i * BLOCK_SIZE + threadIdx.x; + r_t x_v = 0; + if(i_token_pack < (kargs.tokens + index_pack - 1) / index_pack) + { + x_v = reinterpret_cast(p_expert_mesh + + eid * kargs.mesh_stride)[i_token_pack]; + } + + r_t x_r; +#if 0 + if constexpr(index_pack != 1) + { + // shuffle, we must have contiguout thread holds contiguout token + __syncthreads(); + reinterpret_cast(s)[threadIdx.x] = x_v; + __syncthreads(); + + static_for<0, index_pack, 1>{}([&](auto j_) { + constexpr auto j = j_.value; + x_r[j] = reinterpret_cast(s)[threadIdx.x + j * BLOCK_SIZE]; + }); + } +#else + x_r = x_v; +#endif + { +#if 0 +#pragma unroll + for(int j = 0; j < index_pack / 2; j++) + { + int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * BLOCK_SIZE; + index_t x = x_d[j]; + int i_topk = x - 1; // topk of this token + int i_show = x != 0 ? 1 : 0; // has this token or not + int cumsum = i_show; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + + int position = cumsum - i_show; + prev_cumsum = s[0]; // update the last cumsum + + if(i_show) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position] = + MOE_SORTING_MOCK_ID(i_token, i_topk); +#else + p_sorted_token_ids[e_start + position] = i_token; +#endif + p_sorted_weights[e_start + position] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk]; + } + } +#endif + { + d_t i_topk; + d_t i_show; + // = 0; + int cumsum_store = 0; + + static_for<0, index_pack, 1>{}([&](auto j_) { + constexpr auto j = j_.value; + i_topk[j] = static_cast(x_r[j] - 1); + i_show[j] = static_cast(x_r[j] != 0 ? 1 : 0); + cumsum_store += i_show[j]; + }); + int cumsum = cumsum_store; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + prev_cumsum = s[0]; // update the last cumsum + + int position = cumsum - cumsum_store; + static_for<0, index_pack, 1>{}([&](auto j_) { + constexpr auto j = j_.value; + // int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x + j * + // BLOCK_SIZE; + int i_token = + i * BLOCK_SIZE * index_pack + threadIdx.x * index_pack + j; + + if(i_show[j]) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position] = + MOE_SORTING_MOCK_ID(i_token, i_topk[j]); +#else + p_sorted_token_ids[e_start + position] = i_token; +#endif + p_sorted_weights[e_start + position] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk[j]]; + } + position += i_show[j]; + }); + +#if 0 + int i_token = i * BLOCK_SIZE * index_pack + threadIdx.x * 2 + j * BLOCK_SIZE * 2; + index_t x = x_d[j]; + index_t x0 = static_cast(x & 0xffff); + index_t x1 = static_cast(x >> 16); + int i_topk_0 = x0 - 1; // topk of this token + int i_show_0 = x0 != 0 ? 1 : 0; // has this token or not + int i_topk_1 = x1 - 1; // topk of this token + int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not + int cumsum = i_show_0 + i_show_1; + impl::moe_sorting_wave_cumsum(cumsum); + + __syncthreads(); + if(lane_id == warpSize - 1) + { + s[4 + wave_id] = cumsum; + } + __syncthreads(); + + // reduce cross wave + static_for<0, BLOCK_SIZE / warpSize - 1, 1>{}([&](auto i_w) { + IndexType prev = s[4 + i_w]; + prev = wave_id > i_w ? prev : 0; // mask out + cumsum += prev; + }); + cumsum += prev_cumsum; // add previous round cumsum + if(threadIdx.x == BLOCK_SIZE - 1) + { + s[0] = cumsum; + } + __syncthreads(); + + int position_0 = cumsum - i_show_0 - i_show_1; + prev_cumsum = s[0]; // update the last cumsum + + if(i_show_0) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position_0] = + MOE_SORTING_MOCK_ID(i_token, i_topk_0); +#else + p_sorted_token_ids[e_start + position_0] = i_token; +#endif + p_sorted_weights[e_start + position_0] = + p_weights[i_token * kargs.topk_mdiv.divisor + i_topk_0]; + } + + int position_1 = cumsum - i_show_1; + + if(i_show_1) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[e_start + position_1] = + MOE_SORTING_MOCK_ID(i_token + 1, i_topk_1); +#else + p_sorted_token_ids[e_start + position_1] = i_token + 1; +#endif + p_sorted_weights[e_start + position_1] = + p_weights[(i_token + 1) * kargs.topk_mdiv.divisor + i_topk_1]; + } +#endif + } + } + } + + for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += BLOCK_SIZE) + { +#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID + p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(kargs.tokens, kargs.topk_mdiv.divisor); +#else + p_sorted_token_ids[i] = tokens; +#endif + p_sorted_weights[i] = static_cast(0.0); + } + } + } +}; + #undef MOE_SORTING_MOCK_ID } // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp index a98e0d7652..39bc6ca93e 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_problem.hpp @@ -50,20 +50,23 @@ struct MoeSortingProblemEx }; template struct MoeSortingProblemMp { // TODO: this kernel only support warp per row using WeightType = remove_cvref_t; + using MeshType = remove_cvref_t; using IndexType = remove_cvref_t; static constexpr index_t SubTokenTile = SubTokenTile_; static constexpr bool LocalExpertMasking = LocalExpertMasking_; static constexpr bool SkipExpertsWithZeroTokens = SkipExpertsWithZeroTokens_; - static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4); + static_assert(SubTokenTile == 1 || SubTokenTile == 2 || SubTokenTile == 4 || + SubTokenTile == 8 || SubTokenTile == 16); }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 794f7f21f2..35f5170179 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -44,8 +44,11 @@ #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 38ed108f6d..c4d527da63 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -298,7 +298,7 @@ struct BlockUniversalGemmAsBsCr using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); ALdsTile a_warp_tile_; - ALdsTile b_warp_tile_; + BLdsTile b_warp_tile_; template CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index dfb6bfae58..d495c0d950 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -142,15 +142,7 @@ struct BatchedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); - } - else - { - this->template RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); - } + this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index e5b9d17bac..9c25104cd7 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/host/concat.hpp" +#include "ck_tile/core/utility/env.hpp" namespace ck_tile { @@ -607,9 +608,7 @@ struct GemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). */ - template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, @@ -621,7 +620,8 @@ struct GemmKernel { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + MakeGemmTensorViews( + a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -639,9 +639,8 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, smem_ptr_0); } /** @@ -659,9 +658,7 @@ struct GemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). */ - template CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, const BDataType* b_ptr, CDataType* c_ptr, @@ -674,7 +671,8 @@ struct GemmKernel { // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); + MakeGemmTensorViews( + a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -691,9 +689,8 @@ struct GemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I2); - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, smem_ptr_0); + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, smem_ptr_0); } CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const @@ -717,7 +714,9 @@ struct GemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GetSmemSize()]; - if(kargs.k_batch == 1) + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunGemm2LDS(a_ptr, b_ptr, @@ -729,38 +728,15 @@ struct GemmKernel i_m, i_n); } - else - { - if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } } else { - if(kargs.k_batch == 1) + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) { RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); } - else - { - if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm( - a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n); - } - } } } }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index 0e0ee9dbd8..6535f612f1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -217,17 +217,17 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 ////////////// global window & register ///////////////// // A DRAM tile window for load auto a_copy_dram_window = - make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); // B DRAM tile window for load auto b_copy_dram_window = - make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); // A register tile for global load constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution(); @@ -317,25 +317,31 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 BLdsTile b_block_tile1; auto a_lds_ld_window0 = - make_tile_window_linear(a_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); auto a_lds_ld_window1 = - make_tile_window_linear(a_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - ALdsTileDistr); + make_tile_window(a_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + ALdsTileDistr); auto b_lds_ld_window0 = - make_tile_window_linear(b_lds_block0, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block0, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); auto b_lds_ld_window1 = - make_tile_window_linear(b_lds_block1, - make_tuple(number{}, number{}), - {0, 0}, - BLdsTileDistr); + make_tile_window(b_lds_block1, + make_tuple(number{}, number{}), + {0, 0}, + BLdsTileDistr); + + static_assert( + !(is_tile_window_linear_v)&&!(is_tile_window_linear_v)&&!( + is_tile_window_linear_v< + decltype(b_lds_ld_window0)>)&&!(is_tile_window_linear_v), + "LDS windows must not be linear"); Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index e528847438..f6920f1c57 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -17,56 +17,6 @@ namespace ck_tile { struct GemmPipelineAgBgCrCompV4DefaultPolicy : public UniversalGemmBasePolicy { - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using namespace ck_tile; - - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{} / KPack, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackB(); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, number{}, number{}), - make_tuple(number<(kNPerBlock)*KPack>{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple( - make_pass_through_transform(number{}), - make_merge_transform(make_tuple(number{}, number{}))), - make_tuple(sequence<1>{}, sequence<0, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return b_lds_block_desc; - } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index f833ccc849..0b38e7789e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -32,6 +32,8 @@ struct GemmPipelineProblemBase static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); static constexpr bool kPadM = Traits::kPadM; @@ -194,7 +196,8 @@ struct UniversalGemmPipelineProblem static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto TailNum = TailNum_; - static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool TransposeC = Traits::TransposeC; + static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index c504a51ad0..6890cf2f64 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -19,6 +19,245 @@ struct UniversalGemmBasePolicy static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked; static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked; + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using ADataType = remove_cvref_t; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + /** + * @brief Create LDS block descriptor for B tensor. + * + * @tparam Problem Gemm pipeline problem. + * @return B tensor LDS block descriptor. + */ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + // using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + +#if 1 + // if constexpr(std::is_same_v) + { + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + BK0 * number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + BK0 * number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } +#else + else // B is Row Major + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeB(); + using TileEncodingPattern = TileDistributionEncodingPattern2D; + + constexpr auto BK0 = number{}; + constexpr auto BK1 = number{}; + // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N0 = TileEncodingPattern::X0; + constexpr auto N1 = NPerBlock / N0; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto NPerXdl = number{}; + + // constexpr auto KThreadWrite = + // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0 / KThreadRead; + + constexpr auto kfold = + (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1 * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + BK1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + // b_lds_block_desc_unmerged, + // make_tuple(make_merge_transform_v3_division_mod( + // make_tuple(number{}, + // number{}, + // number{}, + // number{})), + // make_merge_transform_v3_division_mod( + // make_tuple(number{}, number{}, number{})), + // make_pass_through_transform(BK1)), + // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), + // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + BK1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // return b_lds_block_desc_bk0_n_bk1; + return b_lds_block_desc_kn; + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( + // make_tuple(BK0, number{}, number{}), + // make_tuple(number{}, number{}, number<1>{}), + // number{}, + // number<1>{}); + + // constexpr auto b_lds_block_desc = transform_tensor_descriptor( + // b_lds_block_desc_bk0_n_bk1, + // make_tuple(make_pass_through_transform(number{}), + // make_merge_transform_v3_division_mod(make_tuple(BK0, + // number{}))), + // make_tuple(sequence<1>{}, sequence<0, 2>{}), + // make_tuple(sequence<0>{}, sequence<1>{})); + + // return b_lds_block_desc; + } +#endif + } + /** * @brief Get the maximum global memory vector load size. * @@ -301,7 +540,7 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() { - constexpr auto a_lds_desc = Derived::template MakeALdsBlockDescriptor(); + constexpr auto a_lds_desc = MakeALdsBlockDescriptor(); constexpr index_t smem_size_a = integer_least_multiple( sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16); return smem_size_a; @@ -310,7 +549,7 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() { - constexpr auto b_lds_desc = Derived::template MakeBLdsBlockDescriptor(); + constexpr auto b_lds_desc = MakeBLdsBlockDescriptor(); constexpr index_t smem_size_b = integer_least_multiple( sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16); return smem_size_b; @@ -330,245 +569,6 @@ struct UniversalGemmBasePolicy struct UniversalGemmPipelineAgBgCrPolicy : public UniversalGemmBasePolicy { - template - CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() - { - using ADataType = remove_cvref_t; - - constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); - - constexpr auto DataTypeSize = sizeof(ADataType); - constexpr auto MLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return a_lds_block_desc; - } - - /** - * @brief Create LDS block descriptor for B tensor. - * - * @tparam Problem Gemm pipeline problem. - * @return B tensor LDS block descriptor. - */ - template - CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() - { - // using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; - - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - -#if 1 - // if constexpr(std::is_same_v) - { - constexpr index_t KPack = GetSmemPackB(); - constexpr auto BK0 = number{}; - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple( - BK0 * number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - BK0 * number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; - } -#else - else // B is Row Major - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t VecLoadSize = GetVectorSizeB(); - using TileEncodingPattern = TileDistributionEncodingPattern2D; - - constexpr auto BK0 = number{}; - constexpr auto BK1 = number{}; - // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N0 = TileEncodingPattern::X0; - constexpr auto N1 = NPerBlock / N0; - - using WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr auto NPerXdl = number{}; - - // constexpr auto KThreadWrite = - // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto KThreadWrite = TileEncodingPattern::Y2; - constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; - constexpr auto K0PerThreadRead = BK0 / KThreadRead; - - constexpr auto kfold = - (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 - ? N0 - : 128 / (BK1 * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - BK1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - // b_lds_block_desc_unmerged, - // make_tuple(make_merge_transform_v3_division_mod( - // make_tuple(number{}, - // number{}, - // number{}, - // number{})), - // make_merge_transform_v3_division_mod( - // make_tuple(number{}, number{}, number{})), - // make_pass_through_transform(BK1)), - // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), - // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - BK1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - // return b_lds_block_desc_bk0_n_bk1; - return b_lds_block_desc_kn; - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( - // make_tuple(BK0, number{}, number{}), - // make_tuple(number{}, number{}, number<1>{}), - // number{}, - // number<1>{}); - - // constexpr auto b_lds_block_desc = transform_tensor_descriptor( - // b_lds_block_desc_bk0_n_bk1, - // make_tuple(make_pass_through_transform(number{}), - // make_merge_transform_v3_division_mod(make_tuple(BK0, - // number{}))), - // make_tuple(sequence<1>{}, sequence<0, 2>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - - // return b_lds_block_desc; - } -#endif - } - template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() { @@ -580,7 +580,9 @@ struct UniversalGemmPipelineAgBgCrPolicy WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2), - Problem::TransposeC>; + Problem::TransposeC, + false, + Problem::UseStructuredSparsity>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy + bool TransposeC_ = false, + bool UseStructuredSparsity_ = false> struct TileGemmUniversalTraits { static constexpr bool kPadM = kPadM_; @@ -49,7 +51,8 @@ struct TileGemmUniversalTraits using BLayout = BLayout_; using CLayout = CLayout_; - static constexpr bool TransposeC = TransposeC_; + static constexpr bool TransposeC = TransposeC_; + static constexpr bool UseStructuredSparsity = UseStructuredSparsity_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 1fd12973f6..f050a8e382 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,6 +7,9 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp" + namespace ck_tile { // fp16 @@ -17,13 +20,24 @@ using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl< using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +#else using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl, 2>>; +#endif +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; +#else using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl, @@ -41,20 +55,50 @@ using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl>>; +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution = WarpGemmImpl, 2>>; +#endif +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2>>; +#endif +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M16N16K32SwizzleBTransposedCDistribution = + WarpGemmImpl, + 1>>; + +using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution = + WarpGemmImpl, + 1>>; +#endif + +#if defined(__gfx950__) +using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaF16F16F32M4N64K16 = WarpGemmImpl, @@ -64,21 +108,38 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl, 4>>; -// bf16 +// fp16 2:4 structured sparsity +using WarpGemmSmfmacF16F16F32M32N32K16 = WarpGemmSmfmacImpl>>; +using WarpGemmSmfmacF16F16F32M16N16K32 = WarpGemmSmfmacImpl>>; + +// bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +#else using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl, 2>>; +#endif +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; +#else using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl, @@ -97,20 +158,38 @@ using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl>>; +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution = WarpGemmImpl, 2>>; +#endif +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution = WarpGemmImpl, 2>>; +#endif +#if defined(__gfx950__) +using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = + WarpGemmImpl>>; +#else using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution = WarpGemmImpl, 2>>; +#endif using WarpGemmMfmaBf16Bf16F32M4N64K16 = WarpGemmImpl, @@ -125,6 +204,14 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl>>; +using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, + 2>>; + using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; @@ -134,6 +221,44 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl>>; + +using WarpGemmMfma_f32_32x32x64_fp8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x64_fp8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x64_bf8_fp8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_32x32x64_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index e7d4c37966..93ccdb5f57 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -356,7 +356,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution } }; -template +template struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB { using Impl = remove_cvref_t; @@ -373,6 +373,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB static constexpr index_t kN = Impl::kM; static constexpr index_t kK = Impl::kK; static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t SFactor = SFactor_; // group how many CM1 together CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } @@ -386,7 +387,7 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB tuple>, sequence<2>, sequence<1>>; - +#if 0 using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>, sequence<2, 2>, sequence<0, 2>>; +#else + // TODO: more test not only 32x32 + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>; +#endif template // c_vec += a_vec * b_vec CK_TILE_DEVICE void operator()(CVecType& c_vec, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 21a865e792..69d22496f1 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -61,6 +61,69 @@ enum class WGAttrCtlEnum DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \ } +// V_MFMA_F32_16x16x32_BF16 +template +struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_bf16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; // FP16 template struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8 @@ -188,6 +251,69 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16 } }; +template +struct WarpGemmAttributeMfmaImplF16F16F32M16N16K32 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32f16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return bit_cast( + __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + template struct WarpGemmAttributeMfmaImplF16F16F32M4N64K4 { @@ -622,7 +748,395 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4 } }; +// gfx950 +template +struct WarpGemmAttributeMfmaImplF16F16F32M32N32K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#elif defined(__gfx908__) + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = bf16_t; + using BDataType = bf16_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl) + else + { +#if defined(__gfx950__) + c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#elif defined(__gfx908__) + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0); +#elif defined(__gfx90a__) || defined(__gfx94__) + CVecType c_vec{0.f}; + static_for<0, 2, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#elif defined(__gfx908__) + CVecType c_vec{0.f}; + static_for<0, 4, 1>{}([&](auto k) { + c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reinterpret_cast&>(a_vec) + .template get_as>()[number{}], + reinterpret_cast&>(b_vec) + .template get_as>()[number{}], + c_vec, + 0, + 0, + 0); + }); + return c_vec; +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + // FP8 +template +struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "v", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "a", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "v", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "v", "v") + } + } + else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva) + { + if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "a", "v") + } + else if constexpr(std::is_same_v && std::is_same_v) + { + DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "a", "v") + } + } + else + { +#if defined(__gfx94__) + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx94__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_f32_316x16x32_bf8_bf8( + bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + template struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base { @@ -809,11 +1323,17 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; - +template +using WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; template using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +template +using WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base; + template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; @@ -822,6 +1342,202 @@ template using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 = WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base; +template +struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 128; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 32; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, + // opsel, scale_b) +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base; + +template +struct WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = AType_; + using BDataType = BType_; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 64; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 32; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + // c_vec += a_vec * b_vec + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a, + // opsel, scale_b) +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0); + else if constexpr(std::is_same_v && std::is_same_v) + c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; +#endif + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { +#if defined(__gfx950__) + if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0)); + else if constexpr(std::is_same_v && std::is_same_v) + return bit_cast(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0)); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + return CVecType{0.f}; +#endif + } +}; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_fp8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_fp8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + +template +using WarpGemmAttributeMfmaImpl_f32_32x32x64_bf8_bf8 = + WarpGemmAttributeMfmaImpl_f32_32x32x64_f8_bf8_base; + // int8 template struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8 diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp new file mode 100644 index 0000000000..84cdf17d66 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp" + +namespace ck_tile { + +/** + * @brief Class describing structured sparsity mfma instructions. + * + * @paragraph Overview "Overview" + * Currently only 2:4 structured sparsity is supported, which is based on requirement that in every + * groups of four continuous elements there are at most two non-zero, which results in processing + * only half of elements in smfmac instruction. Because of structured sparsity A vector in smfmac + * instruction will be smaller than B vector by the factor of CompressionRatio. The indexes of + * non-zero elements are stored in `index` which is an additional parameter to assembly instruction. + * Every pair of two bit indexes are containing information about which two elements in current + * group of 4 values are non-zero and should be used inside smfmac instruction. Structured sparsity + * format is supported only for A matrix for now. + */ +template +struct WarpGemmAttributeSmfmac +{ + using Impl = remove_cvref_t; + + using ADataType = typename Impl::ADataType; + using BDataType = typename Impl::BDataType; + using IdxDataType = typename Impl::IdxDataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::AVecType; + using BVecType = typename Impl::BVecType; + using CVecType = typename Impl::CVecType; + + static constexpr index_t kM = Impl::kM; + static constexpr index_t kN = Impl::kN; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; + static constexpr index_t kCompressionRatio = Impl::CompressionRatio; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; } + + static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, + "Multi-block WarpGemmAttributeSmfmacImpl is not supported"); + + using AWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using BWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + + using CWarpDstrEncoding = tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 1>, + sequence<0, 2>>; + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + bool_constant = {}) const + { + Impl{}(c_vec, a_vec, b_vec, idx, bool_constant{}); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp new file mode 100644 index 0000000000..cd6cd3a399 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_smfmac_impl.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "warp_gemm_attribute_mfma_impl.hpp" + +namespace ck_tile { + +// fp16 2:4 structured sparsity + +template +struct WarpGemmAttributeSmfmacImplF16F16F32M32N32K16 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using IdxDataType = int32_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 32; + static constexpr index_t kN = 32; + static constexpr index_t kK = 16; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 32; + static constexpr index_t kBNLane = 32; + static constexpr index_t kABKLane = 2; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 2; + static constexpr index_t kCNLane = 32; + static constexpr index_t kCM0PerLane = 4; + static constexpr index_t kCM1PerLane = 4; + + static constexpr index_t CompressionRatio = 2; + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + bool_constant = {}) const + { +#if defined(__gfx94_) or defined(__gfx95_) + c_vec = __builtin_amdgcn_smfmac_f32_32x32x16_f16(a_vec, b_vec, c_vec, idx, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = idx; +#endif + } +}; + +template +struct WarpGemmAttributeSmfmacImplF16F16F32M16N16K32 +{ + static constexpr WGAttrCtlEnum Ctrl = Ctrl_; + using ADataType = fp16_t; + using BDataType = fp16_t; + using IdxDataType = int32_t; + using CDataType = float; + + using AVecType = ext_vector_t; + using BVecType = ext_vector_t; + using CVecType = ext_vector_t; + + static constexpr index_t kM = 16; + static constexpr index_t kN = 16; + static constexpr index_t kK = 32; + + static constexpr index_t kAMBlock = 1; + static constexpr index_t kBNBlock = 1; + + static constexpr index_t kAMLane = 16; + static constexpr index_t kBNLane = 16; + static constexpr index_t kABKLane = 4; + static constexpr index_t kABKPerLane = 8; + + static constexpr index_t kCMLane = 4; + static constexpr index_t kCNLane = 16; + static constexpr index_t kCM0PerLane = 1; + static constexpr index_t kCM1PerLane = 4; + + static constexpr index_t CompressionRatio = 2; + + // c_vec += a_vec * b_vec[idx] + template + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + const int32_t& idx, + bool_constant = {}) const + { +#if defined(__gfx94_) or defined(__gfx95_) + c_vec = __builtin_amdgcn_smfmac_f32_16x16x32_f16(a_vec, b_vec, c_vec, idx, 0, 0); +#else + ck_tile::ignore = c_vec; + ck_tile::ignore = a_vec; + ck_tile::ignore = b_vec; + ck_tile::ignore = idx; +#endif + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 9c319b5e5f..b2f5d56d01 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,7 +16,8 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false> struct WarpGemmMfmaDispatcher; // clang-format off @@ -35,6 +36,10 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +// fp16 2:4 structural sparsity +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmSmfmacF16F16F32M32N32K16; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmSmfmacF16F16F32M16N16K32; }; + // bf16 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; @@ -52,14 +57,30 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8; }; + +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8; }; + // clang-format on } // namespace impl @@ -70,7 +91,8 @@ template + bool SwizzleA = false, + bool UseStructuredSparsity = false> using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher::Type; + SwizzleA, + UseStructuredSparsity>::Type; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp new file mode 100644 index 0000000000..9e028ddab0 --- /dev/null +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +namespace ck_tile { + +template +struct WarpGemmSmfmacImpl +{ + using WarpGemmAttribute = remove_cvref_t; + + static constexpr index_t kM = WarpGemmAttribute::kM; + static constexpr index_t kN = WarpGemmAttribute::kN; + static constexpr index_t kK = WarpGemmAttribute::kK; + /// @brief The number of elements in K dimension processed by single thread in wavefront. + /// + /// @note Note that WarpGemm may run MFMA instruction multiple times (on different K). + /// In such situation this value reflects this fact. + static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread; + + using ADataType = typename WarpGemmAttribute::ADataType; + using BDataType = typename WarpGemmAttribute::BDataType; + using CDataType = typename WarpGemmAttribute::CDataType; + + using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding; + using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding; + using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding; + + using AWarpDstr = remove_cvref_t; + using BWarpDstr = remove_cvref_t; + using CWarpDstr = remove_cvref_t; + + using AWarpTensor = static_distributed_tensor; + using BWarpTensor = static_distributed_tensor; + using CWarpTensor = static_distributed_tensor; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() + { + return WarpGemmAttribute_::get_num_of_access(); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Compress A vector for 2:4 structured sparsity instruction by moving all non-zero + /// elements into lower part of a_vec to half its effective size. + /// + /// @param a_vec Vector to be compressed. + /// + /// @return Four 2-bit indexes of non-zero elements locations + /// + template + CK_TILE_DEVICE int32_t compress_a(AVec& a_vec) const + { + int32_t idx = 0b11101110; + + static_for<0, 2, 1>{}([&](auto i) { + ADataType nonzero_elems[2] = {a_vec[i * 4 + 2], a_vec[i * 4 + 3]}; + int32_t non_zero_pos = 0; + + static_for<0, 3, 1>{}([&](auto j) { + if(a_vec[i * 4 + j] != 0.0f) + { + nonzero_elems[non_zero_pos] = a_vec[i * 4 + j]; + idx &= ~(0b11 << 2 * (i * 2 + non_zero_pos)); + idx |= j << 2 * (i * 2 + non_zero_pos); + ++non_zero_pos; + } + }); + a_vec[i * 2] = nonzero_elems[0]; + a_vec[i * 2 + 1] = nonzero_elems[1]; + }); + + return idx; + } + + template + CK_TILE_DEVICE void + operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant = {}) const + { + static_assert(detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v && + detail::is_similiar_distributed_tensor_v); + constexpr auto CompressionRatio = WarpGemmAttribute::kCompressionRatio; + + using AVec = ext_vector_t; + using AVecCompressed = + ext_vector_t; + using BVec = ext_vector_t; + using CVec = ext_vector_t; + + constexpr auto I0 = number<0>{}; + + auto a_vec = a.get_thread_buffer().template get_as()[I0]; + const auto b_vec = b.get_thread_buffer().template get_as()[I0]; + auto c_vec = c.get_thread_buffer().template get_as()[I0]; + + const int32_t idx = compress_a(a_vec); + + // @TODO can we simply set a_vec_pruned to a_vec[0:3]? + const AVecCompressed a_vec_pruned = {a_vec[0], a_vec[1], a_vec[2], a_vec[3]}; + + // c_vec += a_vec * b_vec[idx] + WarpGemmAttribute{}(c_vec, a_vec_pruned, b_vec, idx, bool_constant{}); + + c.get_thread_buffer().template set_as(I0, c_vec); + } +}; + +} // namespace ck_tile diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 7e2482807d..c8d284a1d7 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -79,6 +79,16 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_a = type_convert(i4); } + else if constexpr(is_same_v) + { + // TODO: add support for ColMajor layout as well + if(k % 2 == 1) + v_a = type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))); + else + v_a = type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))); + } else { arg.a_element_op_(v_a, arg.a_m_k_(m, k)); @@ -95,6 +105,16 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_b = type_convert(i4); } + else if constexpr(is_same_v) + { + // TODO: add support for RowMajor layout as well + if(k % 2 == 1) + v_b = type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))); + else + v_b = type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))); + } else { arg.b_element_op_(v_b, arg.b_k_n_(k, n)); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index af735925ed..120bf7484a 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -18,15 +18,19 @@ namespace host { template + index_t ActivationType_ = 0, + bool MulRoutedWeight = true, + typename ComputeTypeA = CDataType, + typename ComputeTypeB = ComputeTypeA> struct ReferenceMoeGemm : public device::BaseOperator { // Argument + static constexpr auto ActivationType = ActivationType_; struct Argument : public device::BaseArgument { Argument(const Tensor& sorted_token_ids, @@ -34,8 +38,11 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id, const index_t sorted_tile_size, const Tensor& a_t_k, + const Tensor& a_scale_t, const Tensor& b_e_n_k, + const Tensor& b_scale_e_n, Tensor& c_t_k_n, + const Tensor& d2, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) @@ -44,8 +51,11 @@ struct ReferenceMoeGemm : public device::BaseOperator max_token_id_{max_token_id}, sorted_tile_size_{sorted_tile_size}, a_t_k_{a_t_k}, + a_scale_t_{a_scale_t}, b_e_n_k_{b_e_n_k}, + b_scale_e_n_{b_scale_e_n}, c_t_k_n_{c_t_k_n}, + d2_{d2}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} @@ -57,8 +67,11 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id_; index_t sorted_tile_size_; const Tensor& a_t_k_; + const Tensor& a_scale_t_; const Tensor& b_e_n_k_; + const Tensor& b_scale_e_n_; Tensor& c_t_k_n_; + const Tensor& d2_; AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; @@ -72,15 +85,22 @@ struct ReferenceMoeGemm : public device::BaseOperator float Run(const Argument& arg) { - auto f_mk_kn_mn = [&](auto m, auto n) { + static_assert(ActivationType < 2, "Not supported activation type"); + const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2]; + auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_t_k_.mDesc.GetLengths()[1]; + AccDataType v_acc_up{0}; + ComputeTypeB v_b_up{0}; AccDataType v_acc{0}; + ComputeTypeA v_a{0}; ComputeTypeB v_b{0}; + const int t = arg.sorted_token_ids_(m) & 0xffffff; const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; const int e = arg.expert_ids_(m / arg.sorted_tile_size_); const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0]; + D2DataType v_topk_w = arg.d2_(m, 0); // expert if(t < token_cnt) { for(int k = 0; k < K; ++k) @@ -96,7 +116,7 @@ struct ReferenceMoeGemm : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else @@ -106,37 +126,79 @@ struct ReferenceMoeGemm : public device::BaseOperator // same for B matrix if constexpr(is_same_v) { - uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; - uint8_t i4 = 0; + uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; + uint8_t i4x2_up = arg.b_e_n_k_(e, k, n + full_n).data; + uint8_t i4 = 0; + uint8_t i4_up = 0; if(k % 2 == 1) - i4 = (i4x2 >> 0) & 0xf; + { + i4 = (i4x2 >> 0) & 0xf; + i4_up = (i4x2_up >> 0) & 0xf; + } else - i4 = (i4x2 >> 4) & 0xf; + { + i4 = (i4x2 >> 4) & 0xf; + i4_up = (i4x2_up >> 4) & 0xf; + } #if CK_USE_PK4_LAYOUT_SHUFFLE - v_b = i4_to_f32_gfx9(i4); + v_b = i4_to_f32_gfx9(i4); + v_b_up = i4_to_f32_gfx9(i4_up); #else - v_b = i4 - 8; + v_b = i4 - 8; + v_b_up = i4_up - 8; #endif } else { arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n)); + arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n)); } v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); + v_acc_up += ck::type_convert(v_a) * + ck::type_convert(v_b_up); } CDataType v_c{0}; + CDataType v_c_up{0}; + if constexpr(MulRoutedWeight) + { + v_acc *= v_topk_w; + v_acc_up *= v_topk_w; + } arg.c_element_op_(v_c, v_acc); + arg.c_element_op_(v_c_up, v_acc_up); - arg.c_t_k_n_(t, topk_id, n) = v_c; + if constexpr(ActivationType == 1) + { + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + if constexpr(is_same_v) + { + v_c_up *= 16; + v_c *= 16; + } + tensor_operation::element_wise::Silu{}(v_c, v_c); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } + else if constexpr(ActivationType == 0) + { + v_c = v_c * arg.b_scale_e_n_(e, n) * arg.a_scale_t_(t); + if constexpr(is_same_v) + { + v_c_up *= 16; + v_c *= 16; + } + tensor_operation::element_wise::Gelu{}(v_c, v_c); + v_c_up = v_c_up * arg.b_scale_e_n_(e, n + full_n) * arg.a_scale_t_(t); + arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up; + } } }; const ck::index_t max_token_id = arg.max_token_id_(0); - make_ParallelTensorFunctor( - f_mk_kn_mn, max_token_id, arg.c_t_k_n_.mDesc.GetLengths()[2])( + make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)( std::thread::hardware_concurrency()); return 0; @@ -162,8 +224,11 @@ struct ReferenceMoeGemm : public device::BaseOperator const Tensor& max_token_id, const index_t sorted_tile_size, const Tensor& a_t_k, + const Tensor& a_scale_n, const Tensor& b_e_n_k, + const Tensor& b_scale_e_n, Tensor& c_t_k_n, + const Tensor& d2, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) @@ -173,8 +238,11 @@ struct ReferenceMoeGemm : public device::BaseOperator max_token_id, sorted_tile_size, a_t_k, + a_scale_n, b_e_n_k, + b_scale_e_n, c_t_k_n, + d2, a_element_op, b_element_op, c_element_op}; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index 1e8a086bc4..5c932fcb18 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -25,6 +25,7 @@ template struct ReferenceMoeGemm2 : public device::BaseOperator @@ -143,7 +144,14 @@ struct ReferenceMoeGemm2 : public device::BaseOperator CDataType v_c{0}; D0DataType v_d0 = arg.d0_(m, n); // a D0DataType v_d1 = arg.d1_(e, n); // b - arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); + if constexpr(MulRoutedWeight) + { + arg.c_element_op_(v_c, v_acc, v_d0, v_d1, v_topk_w); + } + else + { + arg.c_element_op_(v_c, v_acc, v_d0, v_d1, 1.f); + } arg.c_t_n_(t, n) += v_c; } }; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index 649f130c41..e8fdcf1acd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -89,9 +89,28 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - a_m_k_scaled(m, k) = - type_convert(arg.a_m_k_(m, k)) * - type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + if constexpr(is_same_v) + { + // TODO: add support for ColMajor layout as well + if(k % 2 == 1) + a_m_k_scaled(m, k) = + type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + else + a_m_k_scaled(m, k) = + type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } + else + { + a_m_k_scaled(m, k) = + type_convert(arg.a_m_k_(m, k)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } } } @@ -99,9 +118,28 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - b_k_n_scaled(k, n) = - type_convert(arg.b_k_n_(k, n)) * - type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + if constexpr(is_same_v) + { + // TODO: add support for RowMajor layout as well + if(k % 2 == 1) + b_k_n_scaled(k, n) = + type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + else + b_k_n_scaled(k, n) = + type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } + else + { + b_k_n_scaled(k, n) = + type_convert(arg.b_k_n_(k, n)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_wp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_wp.hpp index 07891ea932..90a9fa381d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_wp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_wp.hpp @@ -18,173 +18,6 @@ namespace device { namespace instance { #if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8)) -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( - std::vector, - Row, - F8, - F8, - Tuple, - F16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1( std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - -void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2( - std::vector, - Row, - F8, - F8, - Tuple, - BF16, - PassThrough, - PassThrough, - MultiplyMultiply>>>& - instances); - void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1( std::vector && is_same_v && is_same_v) { - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1( op_ptrs); add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2( @@ -612,33 +250,6 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instances_v2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instances_v2( - op_ptrs); - - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p1( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instances_p2( - op_ptrs); - add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p1( op_ptrs); add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instances_p2( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp new file mode 100644 index 0000000000..4af5143f45 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances( + std::vector>>& instances); + +void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances( + std::vector>>& instances); + +void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances( + std::vector>>& instances); + +void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMX> +{ + using DeviceOp = DeviceGemmMX; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + + add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + + add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + + add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + + add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp index 4218c51ca3..79212e16dd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -7,521 +7,22 @@ #include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef CK_USE_WMMA +#include "gemm_universal_wmma.inc" +#endif +#ifdef CK_USE_XDL +#include "gemm_universal_xdl.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_FP16 -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -#endif -#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances( - std::vector>>& - instances); -#endif -#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( - std::vector>>& - instances); -#endif template > op_ptrs; +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances(op_ptrs); + } + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + op_ptrs); + } + } +#endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs); + } + } +#endif +#endif // CK_USE_WMMA + +#ifdef CK_USE_XDL #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) @@ -822,6 +399,7 @@ struct DeviceOperationInstanceFactory< } #endif +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -831,7 +409,8 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); } } - +#endif +#ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -842,6 +421,8 @@ struct DeviceOperationInstanceFactory< op_ptrs); } } +#endif +#endif // CK_USE_XDL return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc new file mode 100644 index 0000000000..1396437326 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_wmma.inc @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc new file mode 100644 index 0000000000..f0de713834 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_xdl.inc @@ -0,0 +1,521 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v1_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_v2_nkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index ae6fabd0bd..5c0d7283f2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -54,6 +54,28 @@ using device_grouped_conv_bwd_data_xdl_f16_generic_instances = // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_f16_16_16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + template , S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, @@ -108,6 +130,27 @@ using device_grouped_conv_bwd_data_xdl_bf16_generic_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_bf16_16_16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + template , S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, @@ -162,6 +205,28 @@ using device_grouped_conv_bwd_data_xdl_f32_generic_instances = // clang-format on >; +template +using device_grouped_conv_bwd_data_xdl_f32_16_16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> + // clang-format on + >; + template , S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, @@ -194,7 +259,7 @@ using device_grouped_conv_bwd_data_xdl_f32_instances = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4> // clang-format on >; @@ -218,7 +283,7 @@ using device_grouped_conv_bwd_data_xdl_input_fp16_comp_bf8f8_instances = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, LoopScheduler::Default, BF8, F8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp index f491474d38..17ffa65d1c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -33,6 +33,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -55,14 +56,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_comp_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -71,7 +74,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -79,17 +84,17 @@ using device_grouped_conv_fwd_xdl_bf16_comp_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Compute friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3> // clang-format on >; @@ -99,15 +104,17 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_comp_instances_part2 = std::tuple< // clang-format off - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, // AGPR Spill when use permuted lds layout. so, use padding for these two. - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5> // clang-format on >; @@ -117,14 +124,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_comp_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -133,14 +142,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_comp_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> // clang-format on >; @@ -150,22 +161,24 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_comp_instances_part2 = std::tuple< // clang-format off - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, // AGPR Spill when use permuted lds layout. so, use padding for these two. - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -174,17 +187,19 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f32_comp_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -194,14 +209,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_comp_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; @@ -210,14 +227,16 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_comp_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4> // clang-format on >; @@ -227,18 +246,20 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_comp_instances_part2 = std::tuple< // clang-format off - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, // AGPR Spill when use permuted lds layout. so, use padding for these two. - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index c9ea462316..df24b4cbcb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -33,6 +33,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -51,7 +52,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_generic_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -59,7 +62,7 @@ using device_grouped_conv_fwd_xdl_bf16_generic_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> // clang-format on >; @@ -68,7 +71,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -76,24 +81,24 @@ using device_grouped_conv_fwd_xdl_bf16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -102,7 +107,30 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_bf16_16x16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> + // clang-format on + >; + +template , + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_generic_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -110,7 +138,7 @@ using device_grouped_conv_fwd_xdl_f16_generic_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> // clang-format on >; @@ -119,7 +147,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -127,24 +157,24 @@ using device_grouped_conv_fwd_xdl_f16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -153,7 +183,30 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f16_16x16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8> + // clang-format on + >; + +template , + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f32_generic_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -161,7 +214,7 @@ using device_grouped_conv_fwd_xdl_f32_generic_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1> // clang-format on >; @@ -170,7 +223,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f32_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -178,24 +233,24 @@ using device_grouped_conv_fwd_xdl_f32_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; @@ -204,7 +259,30 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_16x16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 4> + // clang-format on + >; + +template , + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_generic_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -212,7 +290,7 @@ using device_grouped_conv_fwd_xdl_int8_generic_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> // clang-format on >; @@ -221,7 +299,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -229,24 +309,24 @@ using device_grouped_conv_fwd_xdl_int8_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -255,7 +335,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_comp_f8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| @@ -264,24 +346,24 @@ using device_grouped_conv_fwd_xdl_f16_comp_f8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #ifdef CK_ENABLE_FP8 // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> #endif // clang-format on >; @@ -291,7 +373,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| @@ -300,24 +384,24 @@ using device_grouped_conv_fwd_xdl_f8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #ifdef CK_ENABLE_FP8 // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> #endif // clang-format on >; @@ -327,7 +411,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| @@ -336,24 +422,24 @@ using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #ifdef CK_ENABLE_BF8 // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8> #endif // clang-format on >; @@ -363,7 +449,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType| @@ -372,24 +460,24 @@ using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8> #endif // clang-format on >; @@ -399,7 +487,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf8_f8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType| @@ -408,24 +498,24 @@ using device_grouped_conv_fwd_xdl_bf8_f8_instances = std::tuple< //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> #endif // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp index 0a85cde3bc..6bb6d255f3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp @@ -25,6 +25,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -36,7 +37,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_large_tensor_bf16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -44,10 +47,10 @@ using device_grouped_conv_fwd_xdl_large_tensor_bf16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // clang-format on >; @@ -56,7 +59,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_large_tensor_f16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -64,10 +69,10 @@ using device_grouped_conv_fwd_xdl_large_tensor_f16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // clang-format on >; @@ -76,7 +81,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -84,9 +91,9 @@ using device_grouped_conv_fwd_xdl_large_tensor_f32_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> // clang-format on >; @@ -95,7 +102,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_large_tensor_int8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -103,9 +112,9 @@ using device_grouped_conv_fwd_xdl_large_tensor_int8_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // clang-format on >; } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp index 1f381af08c..195367ffd7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -33,6 +33,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -52,7 +53,9 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -60,27 +63,27 @@ using device_grouped_conv_fwd_xdl_bf16_mem_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Latency friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; @@ -90,34 +93,36 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f16_mem_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; @@ -127,30 +132,32 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_f32_mem_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; @@ -160,34 +167,36 @@ template + BlockGemmPipelineScheduler BlkGemmPipeSched, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_int8_mem_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, // Memory friendly - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 153cc61b09..182c785978 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -25,6 +25,7 @@ using Empty_Tuple = ck::Tuple<>; using namespace ck::tensor_layout::convolution; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; @@ -38,7 +39,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| @@ -46,9 +49,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> // clang-format on >; @@ -58,16 +61,18 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups| //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge| //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| | //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32> // clang-format on >; @@ -76,7 +81,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -84,9 +91,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> // clang-format on >; @@ -96,7 +103,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -104,9 +113,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> // clang-format on >; @@ -115,7 +124,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -123,9 +134,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32> // clang-format on >; @@ -134,7 +145,9 @@ template + ConvolutionForwardSpecialization ConvSpec, + typename DsDataTypes = Tuple<>, + typename OutElementOp = PassThrough> using device_grouped_conv_fwd_xdl_merged_groups_int8_instances = std::tuple< // clang-format off //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -142,9 +155,9 @@ using device_grouped_conv_fwd_xdl_merged_groups_int8_instances = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, int8_t, int8_t, LoopScheduler::Default, 32> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 12695f4f16..e9ff75a91d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -109,6 +109,8 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_FP32 @@ -117,6 +119,8 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -126,6 +130,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( + op_ptrs); } #endif } @@ -167,6 +173,8 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances( op_ptrs); } @@ -177,6 +185,8 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( op_ptrs); } @@ -188,6 +198,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( op_ptrs); + add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances( op_ptrs); } @@ -237,6 +249,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( + op_ptrs); } #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 @@ -255,6 +269,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances( + op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -264,6 +280,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + op_ptrs); } #endif } @@ -308,6 +326,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances( + op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances( op_ptrs); } @@ -319,6 +339,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances( + op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances( op_ptrs); } @@ -330,6 +352,8 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances( op_ptrs); + add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances( + op_ptrs); add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances( op_ptrs); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index 5be8f29e99..c723be0db8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -69,6 +69,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( @@ -84,6 +98,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( @@ -99,6 +127,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -162,6 +204,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( @@ -310,6 +408,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( @@ -325,6 +437,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + std::vector>>& instances); #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( @@ -403,6 +529,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances( std::vector) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances( @@ -221,6 +222,7 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances( @@ -243,6 +245,7 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances(op_ptrs); add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( op_ptrs); add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( @@ -288,6 +291,7 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances( @@ -484,6 +491,7 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances(op_ptrs); add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances( @@ -503,6 +511,8 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + op_ptrs); add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( op_ptrs); add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( @@ -523,7 +533,73 @@ struct DeviceOperationInstanceFactory && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( + op_ptrs); + } #endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances(op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( + op_ptrs); + } +#endif + } + +#endif // CK_USE_XDL #ifdef CK_USE_WMMA if constexpr(NumDimSpatial == 2 && is_same_v && @@ -537,7 +613,6 @@ struct DeviceOperationInstanceFactory +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +#ifdef CK_USE_XDL +#include "grouped_convolution_forward_bias_relu_xdl.inc" +#endif + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_USE_XDL + // layout NHWGC/GKYXC/NHWGK + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + op_ptrs); + } +#endif + } + // layout NDHWGC/GKZYXC/NDHWGK + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + op_ptrs); + } +#endif + } +#endif // CK_USE_XDL + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc new file mode 100644 index 0000000000..1935f123a8 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu_xdl.inc @@ -0,0 +1,242 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances); + +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc index 48bc8942a8..b830bdce71 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_comp_xdl.inc @@ -283,6 +283,111 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances( PassThrough>>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc index 3900c7a0fb..00351ceefd 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_inter_xdl.inc @@ -171,6 +171,55 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instan PassThrough>>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc index b7815f5023..bd44116057 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc @@ -171,6 +171,55 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instan PassThrough>>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc index 0ea24d0929..df4e95007d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc @@ -51,20 +51,6 @@ void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( - std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -291,20 +236,6 @@ void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( - std::vector>>& instances); #endif } // namespace instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index b934b9aef6..b018737932 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -137,6 +137,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -153,6 +167,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 @@ -169,6 +197,20 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -267,6 +309,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -283,6 +339,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 @@ -299,6 +369,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 @@ -382,6 +466,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 @@ -398,6 +496,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP8 @@ -446,6 +558,20 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_INT8 @@ -517,6 +643,97 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( F8>>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc index 966b883301..9f54c4b633 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc @@ -178,6 +178,55 @@ void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_in PassThrough>>>& instances); #endif +// grouped conv3d forward, NGCDHW/GKCZYX/NGKDHW +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index a16418ec7e..25ea3b2ae4 100755 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -60,6 +60,13 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + # Do not build MX instances if gfx950 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT INST_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + message("removing MX instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() # Do not build WMMA instances if gfx11 targets are not on the target list foreach(source IN LISTS ARGN) if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") @@ -74,48 +81,64 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - # Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 + # Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) + foreach(source IN LISTS ARGN) + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply" AND source MATCHES "_f8_") + message("removing gemm_multiply_multiply_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_") + message("removing gemm_universal_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + endif() + # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply" AND source MATCHES "_f8_") - message("removing gemm_multiply_multiply_f8 instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() + if(NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "gemm_wmma_universal" AND source MATCHES "_f8_") + message("removing gemm_universal_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() endforeach() - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_") - message("removing gemm_universal_f8 instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() - endforeach() - endif() + message("remaining instances: ${ARGN}") #only continue if there are some source files left on the list if(ARGN) set(INST_OBJ) foreach(source IN LISTS ARGN) set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) if(source MATCHES "_xdl") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(source MATCHES "_wmma") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(source MATCHES "mha") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() + + if(source MATCHES "_mx") + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + endif() + #only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() else() if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) - endif() + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + endif() + endif() + if(source MATCHES "gemm_wmma_universal" AND source MATCHES "f8") + list(FILTER INST_TARGETS INCLUDE REGEX "gfx12") endif() set(offload_targets) foreach(target IN LISTS INST_TARGETS) @@ -234,6 +257,10 @@ FOREACH(subdir_path ${dir_list}) if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9")) message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY MX_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx950")) + message("Found only MX instances, but gfx950 is not on the targets list. Skipping.") + set(add_inst 0) endif() if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") @@ -442,4 +469,3 @@ set(DEV_OPS_INC_DIRS ${PROJECT_SOURCE_DIR}/library/include/ck/ ) rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) - diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt index 37233ac5b4..743a0272f7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/CMakeLists.txt @@ -2,18 +2,6 @@ set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES) list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance_v2.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p1.cpp - f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p2.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p1.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p2.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p3.cpp @@ -21,18 +9,6 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p5.cpp f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p6.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp - f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp @@ -41,18 +17,6 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp ) -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p4_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_p5_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") @@ -60,18 +24,6 @@ set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p5.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_default_instance_p6.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp index 4266ab9aa3..4613a0f24d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_bf16/device_gemm_multiply_multiply_wp_xdl_f8_f8_bf16_mk_mfma_mn.hpp @@ -100,7 +100,17 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_ //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 64, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + // clang-format on >; @@ -161,13 +171,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Compute friendly // 256x[64, 256, 32]x128 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 8, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 16, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; @@ -180,13 +190,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // 224x[64, 256, 32]x128 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 14, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 7, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 14, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 14, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 14, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template @@ -198,13 +208,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // 192x[64, 256, 32]x128, 192x[64]x256 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 12, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 6, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 12, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 12, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 12, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template @@ -216,13 +226,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // 160x[64, 256, 32]x128, 160x[64, 96, 32]x256 - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 10, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 5, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 10, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 10, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 10, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; template @@ -234,10 +244,10 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; @@ -249,11 +259,11 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma16x1 //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp index 94e44ee600..dc9db8889a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_multiply_wp/f8_f8_f16/device_gemm_multiply_multiply_wp_xdl_f8_f8_f16_mk_mfma_mn.hpp @@ -115,7 +115,13 @@ using device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, - DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 256, 16, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt new file mode 100644 index 0000000000..0442bed130 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt @@ -0,0 +1,18 @@ +# ONLY MX_KERNELS +set(GEMM_MX_INSTANCES) + +list(APPEND GEMM_MX_INSTANCES + device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp + device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp + device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp + device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp + ) + + +set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + + +add_instance_library(device_gemm_mx_instance ${GEMM_MX_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000..8dc21cbf1f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF8 = bf8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp new file mode 100644 index 0000000000..2b6ccdbeda --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp new file mode 100644 index 0000000000..d3f74b2907 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 8, 16, 16, 16, 1, 1, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..c75e779fea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..ac09df7ea2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..05914e06b5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..68363de523 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; + +template +using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..f4e59cf92d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt index ade65eacf3..18eeefa522 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt @@ -1,7 +1,17 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_UNIVERSAL_INSTANCES) -list(APPEND GEMM_UNIVERSAL_INSTANCES +list(APPEND GEMM_UNIVERSAL_INSTANCES + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp + + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -18,7 +28,7 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp - + device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -57,6 +67,16 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instance.cpp ) +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") @@ -80,6 +100,9 @@ set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm set_source_files_properties(device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") list(APPEND GEMM_UNIVERSAL_INSTANCES + device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp + device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -134,25 +157,28 @@ list(APPEND GEMM_UNIVERSAL_INSTANCES device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp ) - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") - set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_nkpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") add_instance_library(device_gemm_universal_instance ${GEMM_UNIVERSAL_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp new file mode 100644 index 0000000000..5d3bb3f7b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..c9a730de68 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp new file mode 100644 index 0000000000..6c3a641f9f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..cd88edec59 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..b700e78d3d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + // Configurations used during development, mainly for testing + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..9951c02251 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..7b4cd64d33 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..3a607c4178 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp new file mode 100644 index 0000000000..3751dc5a11 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..3971802415 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp new file mode 100644 index 0000000000..222b49eb7d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..36901b4f38 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 0000000000..6960375ed6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + // Configurations used during development, mainly for testing + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..bbc8b92217 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..7f71cf6f59 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 256, 64, 64, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 32, 64, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..331ca8b2ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp new file mode 100644 index 0000000000..2fca3551b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| TypeA| TypeB| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..5087a9d719 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + if(ck::is_gfx11_supported()) + return; + + add_device_operation_instances( + instances, device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..244eb69190 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute| + //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| TypeA| TypeB| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | | | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8>, + DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, F8, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 0000000000..89df765517 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f8_bf16/device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + if(ck::is_gfx11_supported()) + return; + + add_device_operation_instances( + instances, device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 913ebd3a12..0ef09c55ee 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -7,9 +7,15 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_16_16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_16_16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 226dca5083..bf775b04c0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] + void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_bf16_instances<2, - GNHWK, - GKYXC, - Empty_Tuple, - GNHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 64fbf8bbf2..1a3c80e5cf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] + void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f16_instances<2, - GNHWK, - GKYXC, - Empty_Tuple, - GNHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index f9351d96f2..96623a5161 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k] + void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataDefault>{}); // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f32_instances<2, - GNHWK, - GKYXC, - Empty_Tuple, - GNHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + GNHWK, + GKYXC, + Empty_Tuple, + GNHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp new file mode 100644 index 0000000000..f3aded5043 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp index 23aeeaf505..e8c6bc7cbe 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -9,7 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp index beeda26690..3f94d30a55 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -9,7 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + NGKHW, + GKCYX, + Empty_Tuple, + NGCHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp index a1d768f4eb..b5e89c9b7c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp @@ -9,7 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 75e7f61f8a..11e0fc6073 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 231e894be0..a63dd712b6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index dbaece1123..e4b4165928 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k] + void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( std::vector{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp index a67b11f1cf..db048679bd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -49,14 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp index 5c0391a25f..ee9507a80a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp @@ -52,15 +52,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instanc Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp index 726276c461..132d3c8411 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instance.cpp @@ -52,15 +52,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_2x_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp index 8b7bdec2a8..a7deb969ba 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp @@ -49,14 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp index c66114b9a3..d2732547fa 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance.cpp @@ -52,15 +52,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_part2_instance Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp index 93e07e08fb..8a0caebc9f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -48,14 +48,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp index 6acbb7475c..e45df1e107 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" @@ -50,14 +50,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( NHWGK, ConvFwd1x1S1P0>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_comp_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); - if(ck::get_device_name() != "gfx950") { add_device_operation_instances( @@ -86,15 +78,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_int8_comp_instances_part2<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } if(ck::get_device_name() == "gfx950") @@ -125,15 +108,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_int8_comp_instances_2x<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 2afbfdc386..078221f89f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 822ef51e00..3a481dd204 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 79a1fb99a8..5add0f8add 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( Empty_Tuple, GNHWK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..0843325287 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp index 6c5d9b5b94..4ca1b2b85e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp @@ -30,6 +30,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp new file mode 100644 index 0000000000..a82e800bb1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp index f1ccad2add..e3a12fd5f4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp @@ -30,6 +30,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp new file mode 100644 index 0000000000..5918f2479f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp index de7e416e48..467a33deb3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp @@ -30,6 +30,20 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f32_instances( Empty_Tuple, NGKHW, ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<2, + NGCHW, + GKCYX, + Empty_Tuple, + NGKHW, + ConvFwd1x1S1P0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..5b8b62010a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index e567c0df75..0257c7d315 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp new file mode 100644 index 0000000000..7ca27e21a7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 3e42184996..2715506fe2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp new file mode 100644 index 0000000000..74cdbde0ba --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index c035d4c3da..8d3e4d91b1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp index 5c425effd8..465fa927a5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" @@ -46,14 +46,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances( Empty_Tuple, NHWGK, ConvFwd1x1S1P0>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp index e8a763c527..87423801cb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance NHWGK, ConvFwd1x1S1P0, Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp index 3ae3fb5186..ebb213461a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance NHWGK, ConvFwd1x1S1P0, Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp index cb7e912936..c2c8a099b2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instances NHWGK, ConvFwd1x1S1P0, Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp index d787f4b048..11cb853f0d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances NHWGK, ConvFwd1x1S1P0, Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp index 5644289790..1992d7f7c1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances NHWGK, ConvFwd1x1S1P0, Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp index 5b12dad5a3..2b8fd3d9db 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances NHWGK, ConvFwd1x1S1P0, Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp index f667481fa4..5579ec62cc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance NHWGK, ConvFwd1x1S1P0, Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp index 2ff2c7f51f..77f3df2c11 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" @@ -49,15 +49,6 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance NHWGK, ConvFwd1x1S1P0, Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000..98b0b1c4cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/CMakeLists.txt @@ -0,0 +1,16 @@ +# ONLY XDL_KERNELS +add_instance_library(device_grouped_conv2d_fwd_bias_relu_instance + xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp + + xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + + xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + + xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp + xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp + + xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp + xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp new file mode 100644 index 0000000000..75acd604ee --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_2x_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp new file mode 100644 index 0000000000..69a8a4bd9d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp new file mode 100644 index 0000000000..043c724e4a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/comp/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instance.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_part2_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..c58631e169 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_16x16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..cd80f2875f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp similarity index 75% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 5cb313b3ca..a6286b55e8 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -1,37 +1,38 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( +void add_device_grouped_conv2d_fwd_bias_relu_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector, NHWGK, - int8_t, - int8_t, - Empty_Tuple, - int8_t, + BF16, + BF16, + Tuple, + BF16, PassThrough, PassThrough, - PassThrough>>>& instances) + AddRelu>>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_i8_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000..0736325b05 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Interwave, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..0d35ab1b05 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/mem/device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + AddRelu>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..253e8b196e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_bias_relu_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector, + NHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances_2x<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd3x3, + Tuple, + AddRelu>{}); + } + else + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<2, + NHWGC, + GKYXC, + Tuple, + NHWGK, + ConvFwd3x3, + Tuple, + AddRelu>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index a656c79289..4bb05e5000 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -6,15 +6,22 @@ set(GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16_16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_16_16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16_16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_16_16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_vec_transpose_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_vec_transpose_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_vec_transpose_instance.cpp + wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp index 8331ea1fda..41f0235063 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho, + // wo, k] void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_bf16_instances<3, - GNDHWK, - GKZYXC, - Empty_Tuple, - GNDHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp index 1885d49c81..03b8285631 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho, + // wo, k] void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f16_instances<3, - GNDHWK, - GKZYXC, - Empty_Tuple, - GNDHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp index 77135fcc05..59526ba9bc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho, + // wo, k] void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( std::vector{}); + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataDefault>{}); // 2. Filter1x1Stride1Pad0 add_device_operation_instances( instances, - device_grouped_conv_bwd_data_xdl_f32_instances<3, - GNDHWK, - GKZYXC, - Empty_Tuple, - GNDHWC, - ConvBwdDataFilter1x1Stride1Pad0>{}); + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + GNDHWK, + GKZYXC, + Empty_Tuple, + GNDHWC, + ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp new file mode 100644 index 0000000000..3f90c8b907 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 663d41fe0b..f9989dec13 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,8 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index ac0ab44ce3..071d34b94a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,8 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + NDHWGK, + GKZYXC, + Empty_Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp index 50d5cce73d..77127bf7f9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp" @@ -8,8 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp index a9a6b4d281..943c5bab26 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp @@ -9,8 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp index eec3944078..bada2507c2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -9,8 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances( std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3, + NGKDHW, + GKCZYX, + Empty_Tuple, + NGCDHW, + ConvBwdDataDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp index a596482ca8..f1c6f53bf3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp @@ -9,8 +9,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, -// g, k] + void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances( std::vector>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_i8_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp new file mode 100644 index 0000000000..43241454a5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp similarity index 65% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp index a8f723dfec..85a1c9137c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp @@ -1,37 +1,42 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( - std::vector>>& instances) + PassThrough>>>&) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_f16_instances<2, - GNHWC, - GKYXC, - Empty_Tuple, - GNHWK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); + if(ck::get_device_name() != "gfx950") + { +#if 0 // TODO: Improve compilation time and enable these instances + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); +#endif + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp new file mode 100644 index 0000000000..9b8bf4fa42 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instances( + std::vector>>& instances) +{ + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp new file mode 100644 index 0000000000..d02d9f6778 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp similarity index 70% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp index 8c621543a9..eaac75ee9e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_part2_instance.cpp @@ -1,37 +1,42 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( - std::vector>>& instances) + PassThrough>>>&) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_f16_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); + if(ck::get_device_name() != "gfx950") + { +#if 0 // TODO: Improve compilation time and enable these instances + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); +#endif + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp new file mode 100644 index 0000000000..696ea7f34e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..8f113b5234 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp new file mode 100644 index 0000000000..1395447660 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp new file mode 100644 index 0000000000..43b3565c74 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..3b5068d605 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp new file mode 100644 index 0000000000..060eebebc1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp new file mode 100644 index 0000000000..0ddf5bfa48 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f16_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp new file mode 100644 index 0000000000..85b088f416 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instance.cpp new file mode 100644 index 0000000000..dc4f7be9c0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_16x16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_16x16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp new file mode 100644 index 0000000000..2b3e596355 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000..fac3098341 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..f3eccc7dc8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp new file mode 100644 index 0000000000..abea0bea81 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp new file mode 100644 index 0000000000..ba5d9fb1de --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp new file mode 100644 index 0000000000..5a2c4a0d5b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp new file mode 100644 index 0000000000..701b8eb4a4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp similarity index 66% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp index 7649f86971..71bde2faa5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp @@ -1,38 +1,44 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, -// g, k] -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_bf16_instances( std::vector>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_i8_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd3x3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp similarity index 67% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp index fa378af1ee..2e71b76256 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp @@ -1,21 +1,20 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, -// wo, k] -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f16_instances( std::vector>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_f16_instances<3, - GNDHWC, - GKZYXC, - Empty_Tuple, - GNDHWK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f16_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd3x3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp similarity index 67% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp index d41416fd4a..8a53dea612 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/merged_groups/device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp @@ -1,38 +1,44 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, -// wo, k] -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + +void add_device_grouped_conv3d_fwd_xdl_merged_groups_ngcdhw_gkczyx_ngkdhw_f32_instances( std::vector>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_i8_instances<3, - GNDHWC, - GKZYXC, - Empty_Tuple, - GNDHWK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_f32_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd3x3>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000..afdddfec70 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/CMakeLists.txt @@ -0,0 +1,16 @@ +# ONLY XDL_KERNELS +set(GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp + + xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + + xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + + xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp + xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp + + xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp +) + +add_instance_library(device_grouped_conv3d_fwd_bias_relu_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp new file mode 100644 index 0000000000..9819f0ea0b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/comp/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); + } +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp new file mode 100644 index 0000000000..dc3fc7a4bf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_16x16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..a9a8ff8459 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 70% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 8a7bc26178..e58e879973 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/large_tensor/device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -1,38 +1,38 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_large_tensor_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, -// g, k] -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector, NDHWGK, - F16, - F16, - Empty_Tuple, - F16, + BF16, + BF16, + Tuple, + BF16, PassThrough, PassThrough, - PassThrough>>>& instances) + AddRelu>>>& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_wmma_f16_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - Empty_Tuple, - PassThrough, - ConvFwdOddC>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_large_tensor_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp new file mode 100644 index 0000000000..e76052c6e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Interwave, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Interwave, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Interwave, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp new file mode 100644 index 0000000000..0593f3f46a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/mem/device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Intrawave, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Intrawave, + Tuple, + AddRelu>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Intrawave, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..6552f26f88 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_relu/xdl/merged_groups/device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_relu_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + AddRelu>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddRelu>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_merged_groups_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd3x3, + Tuple, + AddRelu>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp index 2054ffbbb3..f7b1d5f1f8 100644 --- a/profiler/include/profiler/profile_gemm_universal_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -9,7 +9,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_universal.hpp" diff --git a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp old mode 100644 new mode 100755 index d145ab1766..e625fae808 --- a/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp @@ -166,7 +166,7 @@ bool profile_gemm_universal_streamk_impl(int do_verification, 0, 1, 2, 3, 4}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile Stream-K+ DP, // 2:2-tile Stream-K + DP - if(Grid_size != -1) + if(Grid_size == -1) { grid_size_list = {Grid_size}; } diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp index 6b24be7d1f..4e0ced347d 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -34,7 +34,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, int init_method, bool do_log, bool time_kernel, - const ck::utils::conv::ConvParam& conv_param) + const ck::utils::conv::ConvParam& conv_param, + ck::index_t split_k = 1) { using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -88,6 +89,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, // reset input to zero in_device_buf.SetZero(); + float max_accumulated_value = 0; if(do_verification) { auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdDataGetWorkSpaceSize(argument_ptr.get()); DeviceMem workspace_dev(workspace_sz); @@ -150,7 +154,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / avg_time; std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; + << gb_per_sec << " GB/s, " << op_name << ", SplitK " << split_k_for_run + << std::endl; if(tflops > best_tflops) { @@ -158,13 +163,39 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, best_tflops = tflops; best_avg_time = avg_time; best_gb_per_sec = gb_per_sec; + best_split_k = split_k_for_run; } if(do_verification) { in_device_buf.FromDevice(in_device.mData.data()); - pass = pass & ck::utils::check_err(in_device, in_host); + using ComputeType = std::conditional_t; + using AccDataType = + std::conditional_t, int32_t, float>; + const index_t num_accums = conv_param.K_; + // Calculate thresholds + auto rtol = ck::utils::get_relative_threshold( + num_accums / split_k_for_run); + auto atol = ck::utils::get_absolute_threshold( + max_accumulated_value / split_k_for_run, num_accums / split_k_for_run); + // Calculate error due to split_k accumulation + auto rtol_split_k = + ck::utils::get_relative_threshold( + split_k_for_run); + auto atol_split_k = + ck::utils::get_absolute_threshold( + max_accumulated_value, split_k_for_run); + // Use higher threshold + rtol = std::max(rtol, rtol_split_k); + atol = std::max(atol, atol_split_k); + + pass = pass & ck::utils::check_err( + in_device, in_host, "Error: Incorrect results!", rtol, atol); + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; if(do_log) { @@ -225,35 +256,47 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification, copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_right_pads_, input_right_pads); + std::vector split_k_list = {1, 2, 4, 8, 16, 32, 64, 128}; + + if(split_k > 0) + { + split_k_list = {split_k}; + } + for(auto& op_ptr : op_ptrs) { - auto argument_ptr = - op_ptr->MakeArgumentPointer(static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - {}, - static_cast(in_device_buf.GetDeviceBuffer()), - out_lengths, - out_strides, - wei_lengths, - wei_strides, - {}, - {}, - in_lengths, - in_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - out_element_op, - wei_element_op, - in_element_op); + for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + {}, + static_cast(in_device_buf.GetDeviceBuffer()), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op, + split_k_list[split_k_id]); - run_impl(op_ptr, argument_ptr); + run_impl(op_ptr, argument_ptr, split_k_list[split_k_id]); + } } std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK " + << best_split_k << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp new file mode 100644 index 0000000000..9d38263d4e --- /dev/null +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_relu.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_conv_fwd_bias_relu_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param) +{ + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::AddRelu; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor host_output(out_g_n_k_wos_desc); + Tensor device_output(out_g_n_k_wos_desc); + Tensor bias(out_g_n_k_wos_desc); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weight.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weight.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weight.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + + // run reference op + if(do_verification) + { + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{}; + + std::array, 1> d_tensors = {bias}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weight, + host_output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + // init host output to zero + host_output.SetZero(); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_avg_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + bool pass = true; + + auto run_impl = [&](auto& op_ptr, auto& argument_ptr) { + // workspace_sz will be equal to 0 for other layout than NGCHW + const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init output to zero before profiling next kernel + out_device_buf.SetZero(); + + std::string op_name = op_ptr->GetTypeString(); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = pass & ck::utils::check_err(device_output, host_output); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + }; + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + AComputeType, + BComputeType>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {bias_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {e_g_n_k_wos_lengths}, + {e_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + run_impl(op_ptr, argument_ptr); + } + + std::cout << "Best configuration parameters:" + << "\nname: " << best_op_name << "\navg_time: " << best_avg_time + << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index 4bfbca5437..dfa6bc1edd 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index 09e03de99c..8fb20f0135 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -6,6 +6,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 367e94de11..fc2ba5a650 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -1,11 +1,12 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" diff --git a/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp index 94ee2a37e4..1b17f05760 100644 --- a/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_multiply_tile_loop_impl.hpp @@ -6,6 +6,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/host_utility/hip_check_error.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" diff --git a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp index 3a4ca24dda..cf3c3a6bae 100644 --- a/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp @@ -6,6 +6,7 @@ #include #include "ck/ck.hpp" +#include "ck/utility/env.hpp" #include "ck/host_utility/hip_check_error.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 9cb70e4670..17c8c277eb 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -58,7 +58,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) list(APPEND PROFILER_SOURCES profile_gemm_b_scale.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_b_scale.cpp) list(APPEND PROFILER_SOURCES profile_gemm_universal_batched.cpp) @@ -76,6 +75,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) endif() + list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) @@ -144,7 +144,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) endif() target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_b_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_b_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_batched_instance) @@ -170,6 +169,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) diff --git a/profiler/src/profile_gemm_multiply_multiply.cpp b/profiler/src/profile_gemm_multiply_multiply.cpp index ad2bb77544..42192b5985 100644 --- a/profiler/src/profile_gemm_multiply_multiply.cpp +++ b/profiler/src/profile_gemm_multiply_multiply.cpp @@ -42,7 +42,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[]) printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: " "f16->f8; 7: f8->bf16, " - "comp f8; 8: int8->bf16; 9: f8->f16, comp f8;)\n"); + "comp f8; 8: int8->bf16; 9: int8->f16, 10. f8->f16;)\n"); printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); diff --git a/profiler/src/profile_gemm_universal.cpp b/profiler/src/profile_gemm_universal.cpp index a22d983da5..7f2393a7e6 100644 --- a/profiler/src/profile_gemm_universal.cpp +++ b/profiler/src/profile_gemm_universal.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -103,8 +103,10 @@ int profile_gemm_universal(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8) using F8 = ck::f8_t; +#endif +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) using I4 = ck::pk_i4_t; #endif @@ -201,7 +203,7 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{}); } -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8) else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{}); @@ -210,6 +212,8 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{}); } +#endif +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN) { return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); diff --git a/profiler/src/profile_grouped_conv_bwd_data.cpp b/profiler/src/profile_grouped_conv_bwd_data.cpp index 1515f1105f..5cdece499e 100644 --- a/profiler/src/profile_grouped_conv_bwd_data.cpp +++ b/profiler/src/profile_grouped_conv_bwd_data.cpp @@ -68,8 +68,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) const bool time_kernel = std::stoi(argv[7]); const int num_dim_spatial = std::stoi(argv[8]); - // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial - if(argc != 8 + 1 + 4 + 6 * num_dim_spatial) + // 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K + if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1) { print_helper_msg(); return 1; @@ -77,6 +77,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv); + ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]); + using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -110,7 +112,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[]) OutDataType, WeiDataType, InDataType>( - do_verification, init_method, do_log, time_kernel, params); + do_verification, init_method, do_log, time_kernel, params, split_k); return pass ? 0 : 1; }; diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 9ee05d1304..a7714b4c73 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -114,17 +114,19 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using GKZYXC = ck::tensor_layout::convolution::GKZYXC; // using GKCX = ck::tensor_layout::convolution::GKXC; - using GKCYX = ck::tensor_layout::convolution::GKCYX; - // using GKCZYX = ck::tensor_layout::convolution::GKZYXC; + using GKCYX = ck::tensor_layout::convolution::GKCYX; + using GKCZYX = ck::tensor_layout::convolution::GKCZYX; using GNWK = ck::tensor_layout::convolution::GNWK; using GNHWK = ck::tensor_layout::convolution::GNHWK; using GNDHWK = ck::tensor_layout::convolution::GNDHWK; // - using NGCHW = ck::tensor_layout::convolution::NGCHW; + using NGCHW = ck::tensor_layout::convolution::NGCHW; + using NGCDHW = ck::tensor_layout::convolution::NGCDHW; - using NGKHW = ck::tensor_layout::convolution::NGKHW; + using NGKHW = ck::tensor_layout::convolution::NGKHW; + using NGKDHW = ck::tensor_layout::convolution::NGKDHW; // using NWGC = ck::tensor_layout::convolution::NWGC; @@ -366,6 +368,23 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{}); } } + // NGCDHW_GKCZYX_NGKDHW + else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/script/convert_miopen_driver_to_profiler.py b/script/convert_miopen_driver_to_profiler.py index 1278b6744d..2ddcbb67cd 100644 --- a/script/convert_miopen_driver_to_profiler.py +++ b/script/convert_miopen_driver_to_profiler.py @@ -126,6 +126,8 @@ def run_ck_grouped_conv_bwd_data(args): args.ck_profier_op = "grouped_conv_bwd_data" parse_data_type(args) parse_layouts(args) + # Test all split K value from the list {1, 2, 4, 8, 32, 64, 128} + args.split_k_value = -1 cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)] cmd += [str(args.data_type), str(args.layout)] @@ -136,6 +138,7 @@ def run_ck_grouped_conv_bwd_data(args): cmd += [str(args.in_channels)] add_conv_params_to_cmd(args, cmd) + cmd += [str(args.split_k_value)] run_ck_profiler_cmd(cmd) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 38fbf5385f..69ffb94488 100755 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -36,6 +36,7 @@ set(REGRESSION_TESTS test_batchnorm_bwd_rank_4 test_grouped_convnd_bwd_data_xdl test_conv_tensor_rearrange + test_gemm_mx ) function(add_test_executable TEST_NAME) @@ -101,11 +102,11 @@ function(add_test_executable TEST_NAME) #only continue if there are some source files left on the list if(ARGN) if(ARGN MATCHES "_xdl") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(ARGN MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -197,13 +198,13 @@ function(add_gtest_executable TEST_NAME) #only continue if there are some source files left on the list if(ARGN) if(ARGN MATCHES "_xdl") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) elseif(ARGN MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_mx") #only build mx example for gfx950 - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -254,6 +255,7 @@ add_subdirectory(reduce) add_subdirectory(convnd_fwd) add_subdirectory(convnd_bwd_data) add_subdirectory(grouped_convnd_fwd) +add_subdirectory(grouped_convnd_fwd_bias_relu) add_subdirectory(grouped_convnd_bwd_weight) add_subdirectory(block_to_ctile_map) add_subdirectory(softmax) @@ -279,6 +281,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9 endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_subdirectory(mx_mfma_op) + add_subdirectory(gemm_mx) endif() add_subdirectory(position_embedding) add_subdirectory(scatter_gather) diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp index d7c39367c8..1464eacfa5 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp @@ -9,6 +9,8 @@ #include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" #include "profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp" +#include + using ck::tensor_operation::device::GemmSpecialization; using ck::tensor_operation::device::MaskingSpecialization; using ck::tensor_operation::device::TensorSpecialization; diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp new file mode 100644 index 0000000000..7d20ee4827 --- /dev/null +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_device_utils.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +namespace ck { +namespace test { + +struct DeviceResources +{ + int computeUnits; + size_t totalMemory; + std::string deviceName; + // Add other relevant properties as needed +}; + +inline DeviceResources GetDeviceResources() +{ + DeviceResources res; + hipDeviceProp_t props; + + hipError_t status = hipGetDeviceProperties(&props, 0); + if(status != hipSuccess) + { + props.multiProcessorCount = 0; + res.computeUnits = 0; + res.totalMemory = 0; + res.deviceName = "Unknown"; + return res; + } + + res.computeUnits = props.multiProcessorCount; + res.totalMemory = props.totalGlobalMem; + res.deviceName = props.name; + + return res; +} + +// Device capability tiers +enum class DeviceCapabilityTier +{ + LOW, // Low resources devices (CU less than 80) + MEDIUM, // Mid-range devices + HIGH // High resources devices (CU hiher than 100) +}; + +inline DeviceCapabilityTier DetermineDeviceTier() +{ + DeviceResources res = GetDeviceResources(); + + // Adjust these thresholds based on your device specifics + if(res.computeUnits < 80) + { + return DeviceCapabilityTier::LOW; + } + else if(res.computeUnits < 100) + { + return DeviceCapabilityTier::MEDIUM; + } + else + { + return DeviceCapabilityTier::HIGH; + } +} + +} // namespace test +} // namespace ck diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp index 8136257a24..8d894576c4 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" +#include "test_batched_gemm_device_utils.hpp" template class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16 @@ -110,14 +111,45 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Bench_BF16_Irregul TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Bench_BF16) { - this->lengths_ = std::vector>{ - {256, 256, 64, 64, 48, 16}, - {256, 256, 128, 128, 48, 16}, - {512, 512, 64, 64, 48, 16}, - {512, 512, 128, 128, 48, 16}, - {1024, 1024, 64, 64, 48, 16}, - {1024, 1024, 128, 128, 48, 16}, - }; + + // Get device capability tier + auto deviceTier = ck::test::DetermineDeviceTier(); + + // Configure test sizes based on device tier + if(deviceTier == ck::test::DeviceCapabilityTier::LOW) + { + // Minimal test sizes for low resource devices + this->lengths_ = std::vector>{ + {256, 256, 64, 64, 16, 8}, {256, 256, 128, 128, 16, 8}, {512, 512, 64, 64, 8, 4}}; + std::cout << "Running reduced benchmarks for low-resource device" << std::endl; + } + else if(deviceTier == ck::test::DeviceCapabilityTier::MEDIUM) + { + // Medium test sizes + this->lengths_ = std::vector>{{256, 256, 64, 64, 24, 12}, + {256, 256, 128, 128, 24, 12}, + {512, 512, 64, 64, 16, 8}, + {512, 512, 128, 128, 16, 8}, + {1024, 1024, 64, 64, 8, 4}, + {1024, 1024, 128, 128, 8, 4}}; + std::cout << "Running medium benchmarks for mid-tier device" << std::endl; + } + else + { + // Full test sizes for high resource devices + this->lengths_ = std::vector>{{256, 256, 64, 64, 48, 16}, + {256, 256, 128, 128, 48, 16}, + {512, 512, 64, 64, 48, 16}, + {512, 512, 128, 128, 48, 16}, + {1024, 1024, 64, 64, 48, 16}, + {1024, 1024, 128, 128, 48, 16}, + {2048, 2048, 64, 64, 48, 16}, + {2048, 2048, 128, 128, 48, 16}, + {4096, 4096, 64, 64, 48, 16}, + {4096, 4096, 128, 128, 48, 16}}; + std::cout << "Running full benchmarks for high-performance device" << std::endl; + } + this->bench_ = true; this->verify_ = false; this->Run(); @@ -127,9 +159,20 @@ using ck::tensor_operation::device::GemmSpecialization; TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch) { + + // Get device capability tier + auto deviceTier = ck::test::DetermineDeviceTier(); + int P = 120; // requires padding int Q = 128; // do not require padding + // For lower-end devices, we might need to skip some tests + if(deviceTier == ck::test::DeviceCapabilityTier::LOW) + { + std::cout << "Skipping GemmSpecialization tests for low-resource device" << std::endl; + return; + } + // IsSupported(M, N, K, O) // clang-format off EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); @@ -153,15 +196,25 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch) { - // IsSupported(M, N, K, O) - // clang-format off - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128)); - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); - // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0 - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); - // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0 - EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129)); + EXPECT_FALSE( + DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{} + .IsSupported(128, 128, 120, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128< + GemmSpecialization::MNKPadding>{} + .IsSupported(128, 128, 128, 120)); + // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % + // ABSrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128< + GemmSpecialization::MNKOPadding>{} + .IsSupported(128, 128, 129, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128< + GemmSpecialization::MNKOPadding>{} + .IsSupported(128, 128, 130, 128)); + // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % + // B1SrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128< + GemmSpecialization::MNKOPadding>{} + .IsSupported(128, 128, 128, 129)); // clang-format on } diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp index 81d404109f..3a86736f44 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp @@ -3,6 +3,7 @@ #include "gtest/gtest.h" #include "test_batched_gemm_softmax_gemm_permute_util.hpp" +#include "test_batched_gemm_device_utils.hpp" template class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16 @@ -132,9 +133,19 @@ using ck::tensor_operation::device::GemmSpecialization; TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch) { + // Get device capability tier + auto deviceTier = ck::test::DetermineDeviceTier(); + int P = 120; // requires padding int Q = 128; // do not require padding + // For lower-end devices, we might need to skip some tests + if(deviceTier == ck::test::DeviceCapabilityTier::LOW) + { + std::cout << "Skipping GemmSpecialization tests for low-resource device" << std::endl; + return; + } + // IsSupported(M, N, K, O) // clang-format off EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 0af3ef3b34..4633f23ded 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -81,10 +81,13 @@ class TestCkTileBatchedGemm : public ::testing::Test float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::BatchedGemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -138,11 +142,29 @@ class TestCkTileBatchedGemm : public ::testing::Test return ave_time; }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 7701e451ad..3e7296b1eb 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -1,8 +1,28 @@ # Currently ck_tile is only built on gfx94/gfx95 +set(EXAMPLE_GEMM_COMPILE_OPTIONS "") +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +set(EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS "") +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +list(APPEND EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS + -mllvm + -enable-noalias-to-md-conversion=0 +) + +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") add_gtest_executable(test_ck_tile_gemm_pipeline_mem test_gemm_pipeline_mem.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) + + target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) else() message("Skipping ck_tile_gemm tests for current target") endif() +endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp index d81e870ffc..8944e6865d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp @@ -9,7 +9,7 @@ class TestCkTileGemmPipelineCompV3 : public TestCkTileGemmPipeline #define TEST_SUITE_NAME TestCkTileGemmPipelineCompV3 -TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesMem); +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV3, KernelTypesCompV3); #include "test_gemm_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp index 1da0028f63..22e77fac41 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv4.cpp @@ -9,7 +9,7 @@ class TestCkTileGemmPipelineCompV4 : public TestCkTileGemmPipeline #define TEST_SUITE_NAME TestCkTileGemmPipelineCompV4 -TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesMem); +TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV4, KernelTypesCompV4); #include "test_gemm_pipeline_ut_cases.inc" diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 1b997ddbce..85742cb3de 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -63,6 +63,19 @@ struct GemmPipelineTypeSelector using pipeline = ck_tile::GemmPipelineAgBgCrCompV4; }; +template +void try_run(ck_tile::TailNumber tn) +{ + if constexpr(Pipeline::PrefetchStages > static_cast(TN)) + { + if(tn == TN) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + } +} + template class TestCkTileGemmPipeline : public ::testing::Test { @@ -138,9 +151,12 @@ class TestCkTileGemmPipeline : public ::testing::Test const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GemmKernel; auto kargs = Kernel::MakeKernelArgs(args); @@ -193,15 +210,32 @@ class TestCkTileGemmPipeline : public ::testing::Test s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if constexpr(PipelineType == GemmPipelineType::CompV3) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { @@ -219,86 +253,43 @@ class TestCkTileGemmPipeline : public ::testing::Test // Tail pipeline One to Seven if(tail_num == ck_tile::TailNumber::One) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } - if constexpr(BaseGemmPipeline::PrefetchStages > 2) - { - if(tail_num == ck_tile::TailNumber::Two) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 3) - { - if(tail_num == ck_tile::TailNumber::Three) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 4) - { - if(tail_num == ck_tile::TailNumber::Four) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 5) - { - if(tail_num == ck_tile::TailNumber::Five) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 6) - { - if(tail_num == ck_tile::TailNumber::Six) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } - if constexpr(BaseGemmPipeline::PrefetchStages > 7) - { - if(tail_num == ck_tile::TailNumber::Seven) - { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); - } - } + auto check_tail = [&](auto... TNs) { + (try_run(tail_num), ...); + }; + + check_tail( + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}); } if constexpr(PipelineType == GemmPipelineType::CompV4) { if(tail_num == ck_tile::TailNumber::Three) { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } else { - Run(ck_tile::bool_constant{}, - ck_tile::integral_constant{}); + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); } } } @@ -307,7 +298,8 @@ class TestCkTileGemmPipeline : public ::testing::Test // Tail number always Full - #PrefetchStages if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index b125d19762..3dec229643 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -102,10 +102,13 @@ class TestCkTileGroupedGemm : public ::testing::Test float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem>; + UniversalGemmProblem::TransposeC, + memory_operation>>; using Kernel = ck_tile::GroupedGemmKernel; auto kargs = Kernel::MakeKargs(gemm_descs); @@ -164,11 +168,29 @@ class TestCkTileGroupedGemm : public ::testing::Test return ave_time; }; + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + if(has_hot_loop) { if(tail_num == ck_tile::TailNumber::Full) { - Run(ck_tile::bool_constant{}, + RunSplitk( + ck_tile::bool_constant{}, ck_tile::integral_constant{}); } else diff --git a/test/data_type/test_bf8_ocp.cpp b/test/data_type/test_bf8_ocp.cpp index 9d4ee38b15..285e7e69fc 100644 --- a/test/data_type/test_bf8_ocp.cpp +++ b/test/data_type/test_bf8_ocp.cpp @@ -1,13 +1,19 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" +#include "ck/library/utility/device_memory.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" using ck::bf8_ocp_t; +using ck::bf8x2_ocp_t; +using ck::bhalf2_t; +using ck::bhalf_t; using ck::f8_convert_rne; using ck::f8_convert_sr; +using ck::float2_t; +using ck::half2_t; using ck::half_t; using ck::type_convert; @@ -266,3 +272,590 @@ TEST(BF8OCP, ConvertFP16Stochastic) const auto bf8_nan = f8_convert_sr(ck::NumericLimits::QuietNaN()); ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data)); } + +constexpr uint64_t test_size = 256 + 6; + +__host__ __device__ void +test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = type_convert(bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // bf8x2 -> fp32x2 + bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + + float2_t f32x2 = type_convert(bf8x2); + p_test[i++] = f32x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f32x2[1]; + if(i >= N) + { + return; + } + + // fp32x2 -> bf8x2 + f32x2 = {-4.0f, 2.0f}; + bf8x2 = f8_convert_rne(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + bf8x2 = f8_convert_sr(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(BF8OCP, HostFP32BF8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_fp32_bf8_type_convert(test_size, out.data(), &completed); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(bf8_ocp_t{bf8_uid}); + } + + // /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -14.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -16.0f)); + + // fp32x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + test_fp32_bf8_type_convert(N, p_test, p_completed); +} + +TEST(BF8OCP, DeviceFP32BF8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp32_bf8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(bf8_ocp_t{bf8_uid}); + } + + /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -14.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -16.0f)); + + // fp32x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = type_convert(bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // bf8x2 -> fp16x2 + bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + + half2_t f16x2 = type_convert(bf8x2); + p_test[i++] = f16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f16x2[1]; + if(i >= N) + { + return; + } + + // fp16x2 -> bf8x2 + f16x2 = {-4.0f, 2.0f}; + bf8x2 = f8_convert_rne(f16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + bf8x2 = f8_convert_sr(f16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(BF8OCP, HostFP16BF8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_fp16_bf8_type_convert(test_size, out.data(), &completed); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // fp16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + test_fp16_bf8_type_convert(N, p_test, p_completed); +} + +TEST(BF8OCP, DeviceFP16BF8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(half_t)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp16_bf8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // bf8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // fp16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + uint8_t bf8_uid = static_cast(bf8_id); + auto v = type_convert(bf8_ocp_t{bf8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // bf8x2 -> bf16x2 + bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16 + + bhalf2_t bf16x2 = type_convert(bf8x2); + p_test[i++] = bf16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = bf16x2[1]; + if(i >= N) + { + return; + } + + // bf16x2 -> bf8x2 + bf16x2 = {type_convert(-4.0f), type_convert(2.0f)}; + bf8x2 = f8_convert_rne(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + bf8x2 = f8_convert_sr(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(bf8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(BF8OCP, HostBF16BF8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_bf16_bf8_type_convert(test_size, out.data(), &completed); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // bf8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // bf16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void +device_test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + test_bf16_bf8_type_convert(N, p_test, p_completed); +} + +TEST(BF8OCP, DeviceBF16BF8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(bhalf_t)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_bf16_bf8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set bf8_nan_ids; + bf8_nan_ids.insert(0b11111111); + bf8_nan_ids.insert(0b01111111); + bf8_nan_ids.insert(0b11111101); + bf8_nan_ids.insert(0b01111101); + bf8_nan_ids.insert(0b11111110); + bf8_nan_ids.insert(0b01111110); + for(auto bf8_nan_id : bf8_nan_ids) + { + auto idx = bf8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++) + { + if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end()) + continue; + + uint8_t bf8_uid = static_cast(bf8_id); + auto idx = bf8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(bf8_ocp_t{bf8_uid})) + << " bf8_id: " << bf8_id << std::endl + << type_convert(type_convert(bf8_ocp_t{bf8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // bf8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -14.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -16.0f))); + + // bf16x2 -> bf8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} diff --git a/test/data_type/test_fp8_ocp.cpp b/test/data_type/test_fp8_ocp.cpp index 944dd89930..bf562112c8 100644 --- a/test/data_type/test_fp8_ocp.cpp +++ b/test/data_type/test_fp8_ocp.cpp @@ -1,13 +1,19 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" +#include "ck/library/utility/device_memory.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +using ck::bhalf2_t; +using ck::bhalf_t; using ck::f8_convert_rne; using ck::f8_convert_sr; using ck::f8_ocp_t; +using ck::f8x2_ocp_t; +using ck::float2_t; +using ck::half2_t; using ck::half_t; using ck::type_convert; @@ -248,3 +254,566 @@ TEST(FP8OCP, ConvertFP16Stochastic) auto f8_nan = f8_convert_sr(ck::NumericLimits::QuietNaN()); ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data)); } + +constexpr uint64_t test_size = 256 + 6; + +__host__ __device__ void +test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + uint8_t fp8_uid = static_cast(fp8_id); + auto v = type_convert(f8_ocp_t{fp8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // fp8x2 -> fp32x2 + f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9 + + float2_t f32x2 = type_convert(fp8x2); + p_test[i++] = f32x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f32x2[1]; + if(i >= N) + { + return; + } + + // fp32x2 -> fp8x2 + f32x2 = {-4.0f, 2.0f}; + fp8x2 = f8_convert_rne(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + fp8x2 = f8_convert_sr(f32x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(FP8OCP, HostFP32FP8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_fp32_fp8_type_convert(test_size, out.data(), &completed); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(f8_ocp_t{fp8_uid}); + } + + // /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -6.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -9.0f)); + + // fp32x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed) +{ + test_fp32_fp8_type_convert(N, p_test, p_completed); +} + +TEST(FP8OCP, DeviceFP32FP8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(float)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp32_fp8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx]; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(f8_ocp_t{fp8_uid}); + } + + /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp32x2 + EXPECT_EQ(out[i++], -powf(2.0f, -6.0f)); + EXPECT_EQ(out[i++], powf(2.0f, -9.0f)); + + // fp32x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + // SR + EXPECT_EQ(out[i++], -4.0f); + EXPECT_EQ(out[i++], 2.0f); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + uint8_t fp8_uid = static_cast(fp8_id); + auto v = type_convert(f8_ocp_t{fp8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // fp8x2 -> fp16x2 + f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9 + + half2_t f16x2 = type_convert(fp8x2); + p_test[i++] = f16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = f16x2[1]; + if(i >= N) + { + return; + } + + // fp16x2 -> fp8x2 + f16x2 = {-4.0f, 2.0f}; + fp8x2 = f8_convert_rne(f16x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + fp8x2 = f8_convert_sr(f16x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(FP8OCP, HostFP16FP8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_fp16_fp8_type_convert(test_size, out.data(), &completed); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(type_convert(f8_ocp_t{fp8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f))); + + // fp16x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void device_test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed) +{ + test_fp16_fp8_type_convert(N, p_test, p_completed); +} + +TEST(FP8OCP, DeviceFP16FP8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(half_t)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_fp16_fp8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(type_convert(f8_ocp_t{fp8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // fp8x2 -> fp16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f))); + + // fp16x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__host__ __device__ void +test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + if(p_completed == nullptr) + { + return; + } + + uint64_t& i = *p_completed; + i = 0; + + if(p_test == nullptr) + { + return; + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + uint8_t fp8_uid = static_cast(fp8_id); + auto v = type_convert(f8_ocp_t{fp8_uid}); + p_test[i] = v; + i++; + if(i >= N) + { + return; + } + } + + /// Test vector conversion + // fp8x2 -> bf16x2 + f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9 + + bhalf2_t bf16x2 = type_convert(fp8x2); + p_test[i++] = bf16x2[0]; + if(i >= N) + { + return; + } + p_test[i++] = bf16x2[1]; + if(i >= N) + { + return; + } + + // bf16x2 -> fp8x2 + bf16x2 = {type_convert(-4.0f), type_convert(2.0f)}; + fp8x2 = f8_convert_rne(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } + + fp8x2 = f8_convert_sr(bf16x2); // expect {-4, 2} + + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<0>{})); //-4f + if(i >= N) + { + return; + } + p_test[i++] = type_convert(fp8x2.AsType()(ck::Number<1>{})); // 2f + if(i >= N) + { + return; + } +} + +TEST(FP8OCP, HostBF16FP8Convert) +{ + std::vector out(test_size, -1.0f); + uint64_t completed = 0; + + test_bf16_fp8_type_convert(test_size, out.data(), &completed); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(type_convert(f8_ocp_t{fp8_uid})); + } + + // /// Test vector conversions + + auto i = 256; + + // fp8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f))); + + // bf16x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} + +__global__ void +device_test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed) +{ + test_bf16_fp8_type_convert(N, p_test, p_completed); +} + +TEST(FP8OCP, DeviceBF16FP8Convert) +{ + std::vector out(test_size, -1.0f); + + DeviceMem device_out(test_size * sizeof(bhalf_t)); + DeviceMem device_completed(sizeof(uint64_t)); + + device_out.SetValue(-21.0f); + device_completed.SetValue(-21.0f); + + device_test_bf16_fp8_type_convert<<<1, 1>>>( + test_size, + static_cast(device_out.GetDeviceBuffer()), + static_cast(device_completed.GetDeviceBuffer())); + + uint64_t completed = 0; + device_completed.FromDevice(&completed); + device_out.FromDevice(out.data()); + + std::set fp8_nan_ids; + fp8_nan_ids.insert(0b11111111); //-NaN + fp8_nan_ids.insert(0b01111111); // +NaN + for(auto fp8_nan_id : fp8_nan_ids) + { + auto idx = fp8_nan_id; + ASSERT_TRUE(std::isnan(type_convert(out[idx]))) + << "idx: " << idx << " out[idx]: " << type_convert(out[idx]); + } + + for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++) + { + if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end()) + continue; + + uint8_t fp8_uid = static_cast(fp8_id); + auto idx = fp8_uid; + ASSERT_FLOAT_EQ(out[idx], type_convert(f8_ocp_t{fp8_uid})) + << " fp8_id: " << fp8_id << std::endl + << type_convert(type_convert(f8_ocp_t{fp8_uid})); + } + + /// Test vector conversions + + auto i = 256; + + // fp8x2 -> bf16x2 + EXPECT_EQ(out[i++], type_convert(-powf(2.0f, -6.0f))); + EXPECT_EQ(out[i++], type_convert(powf(2.0f, -9.0f))); + + // bf16x2 -> fp8x2 + // RNE + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + // SR + EXPECT_EQ(out[i++], type_convert(-4.0f)); + EXPECT_EQ(out[i++], type_convert(2.0f)); + + EXPECT_EQ(test_size, completed); + EXPECT_EQ(test_size, i); +} diff --git a/test/data_type/test_mx_fp4.cpp b/test/data_type/test_mx_fp4.cpp index 449f6fc777..7aca42567c 100644 --- a/test/data_type/test_mx_fp4.cpp +++ b/test/data_type/test_mx_fp4.cpp @@ -240,6 +240,7 @@ TEST(MXFP4, HostScaledConvert) EXPECT_EQ(test_size, i); } +#if !CK_TEMP_DISABLE_FP4_TESTS __global__ void test_mx_fp4_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed) { test_mx_fp4_scaled_convert(N, p_test, p_completed); @@ -539,3 +540,4 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert) EXPECT_EQ(N, completed); EXPECT_EQ(N, i); } +#endif // CK_TEMP_DISABLE_FP4_TESTS diff --git a/test/gemm_mx/CMakeLists.txt b/test/gemm_mx/CMakeLists.txt new file mode 100644 index 0000000000..71a0a98f2d --- /dev/null +++ b/test/gemm_mx/CMakeLists.txt @@ -0,0 +1,4 @@ +add_gtest_executable(test_gemm_mx test_gemm_mx.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_mx PRIVATE utility device_gemm_mx_instance) + endif() diff --git a/test/gemm_mx/test_gemm_mx.cpp b/test/gemm_mx/test_gemm_mx.cpp new file mode 100644 index 0000000000..2c976a217f --- /dev/null +++ b/test/gemm_mx/test_gemm_mx.cpp @@ -0,0 +1,277 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "test_gemm_mx_util.hpp" + +using E8M0 = ck::e8m0_bexp_t; +using F8 = ck::f8_t; +using BF8 = ck::bf8_t; +using F6 = ck::f6_t; +using BF6 = ck::bf6_t; +using F4 = ck::f4_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmMX_MK_NK + : public ck::test::TestGemmMX, Tuple>::type> +{ +}; + +template +class TestGemmMX_MK_KN + : public ck::test::TestGemmMX, Tuple>::type> +{ +}; + +template +class TestGemmMX_KM_NK + : public ck::test::TestGemmMX, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_F8_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + // ADataType, BDataType, CDataType, ScaleBlockSize + std::tuple< F8, F8, F16, ck::Number<32> >, + std::tuple< F8, F8, BF16, ck::Number<32> > +#endif + >; + +using KernelTypes_BF8_F8_MK_KN = ::testing::Types< +#if defined(CK_ENABLE_FP8) + // ADataType, BDataType, CDataType, ScaleBlockSize + std::tuple< BF8, F8, F16, ck::Number<32> > +#endif + >; + +using KernelTypes_F8_KM_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + // ADataType, BDataType, CDataType, ScaleBlockSize + std::tuple< F8, F8, BF16, ck::Number<32> > +#endif + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_F8_MK_NK); +TYPED_TEST_SUITE(TestGemmMX_MK_KN, KernelTypes_BF8_F8_MK_KN); +TYPED_TEST_SUITE(TestGemmMX_KM_NK, KernelTypes_F8_KM_NK); + +/// A: RowMajor +/// B: ColMajor +/// C: RowMajor + +TYPED_TEST(TestGemmMX_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 256; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 256; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_NK, Regular) +{ + std::vector Ms{3840}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_NK, Large) +{ + std::vector> test_sizes{{5120, 5120}, {3840, 5120}, {4096, 4096}}; + + constexpr int K = 4096; + constexpr int StrideA = K; + constexpr int StrideB = K; + + for(auto test_size : test_sizes) + { + auto M = test_size.first; + auto N = test_size.second; + + const auto StrideC = N; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +/// A: RowMajor +/// B: RowMajor +/// C: RowMajor + +TYPED_TEST(TestGemmMX_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 256; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 256; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_KN, Regular) +{ + std::vector Ms{3840}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_MK_KN, Large) +{ + std::vector> test_sizes{{5120, 5120}, {3840, 5120}, {4096, 4096}}; + + constexpr int K = 4096; + constexpr int StrideA = K; + + for(auto test_size : test_sizes) + { + auto M = test_size.first; + auto N = test_size.second; + + const auto StrideB = N; + const auto StrideC = N; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +/// A: ColMajor +/// B: ColMajor +/// C: RowMajor + +TYPED_TEST(TestGemmMX_KM_NK, SmallN) +{ + constexpr int M = 256; + std::vector Ns{1, 2, 3, 4, 5, 6}; + constexpr int K = 512; + + constexpr int StrideA = M; + constexpr int StrideB = K; + + for(int N : Ns) + { + const auto new_N = N * 8; + const auto StrideC = new_N; + this->Run(M, new_N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmMX_KM_NK, MidLargeN) +{ + constexpr int M = 256; + std::vector Ns{127, 255, 312, 799, 1573}; + constexpr int K = 512; + + constexpr int StrideA = M; + constexpr int StrideB = K; + + for(int N : Ns) + { + const auto new_N = (N + 7) / 8 * 8; + const auto StrideC = new_N; + this->Run(M, new_N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmMX_KM_NK, Regular) +{ + std::vector Ms{3840}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, M, StrideB, StrideC); +} + +TYPED_TEST(TestGemmMX_KM_NK, Large) +{ + std::vector> test_sizes{{5120, 5120}, {3840, 5120}, {4096, 4096}}; + + constexpr int K = 4096; + constexpr int StrideB = K; + + for(auto test_size : test_sizes) + { + auto M = test_size.first; + auto N = test_size.second; + + const auto StrideA = M; + const auto StrideC = N; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} diff --git a/test/gemm_mx/test_gemm_mx_util.hpp b/test/gemm_mx/test_gemm_mx_util.hpp new file mode 100644 index 0000000000..02833daeb4 --- /dev/null +++ b/test/gemm_mx/test_gemm_mx_util.hpp @@ -0,0 +1,498 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/number.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +namespace ck { +namespace test { + +namespace { +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +} // namespace + +template +bool profile_gemm_mx_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch, + int n_warmup, + int n_iter, + uint64_t rotating = 0) +{ + if(K % ScaleBlockSize != 0) + { + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); + }; + + using ScaleDataType = e8m0_bexp_t; + using AScaleLayout = Row; + using BScaleLayout = Col; + + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + auto f_get_default_stride = + [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor a_m_k_scale(f_host_tensor_descriptor( + M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A + Tensor b_k_n_scale(f_host_tensor_descriptor( + K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::size_t total_gemm_needed = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() + + a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; + + switch(init_method) + { + case 0: // Initializations for development and debugging + ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); + ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); + ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); + if(do_log) + { + std::cout << "Init A = {1}" << std::endl; + std::cout << "Init A scale = {2.0}" << std::endl; + std::cout << "Init B = {0.5}" << std::endl; + std::cout << "Init B scale = {1.0}" << std::endl; + std::cout << "Expect C = {K}" << std::endl; + } + break; + + case 1: + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] + b_k_n.GenerateTensorValue(GeneratorTensor_2{-4, 5}); // Z[-4,4] + + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + + break; + + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1] + + b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); + break; + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_log > 0) + std::cout << "Device memory allocation..." << std::endl; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a_scale_device_buf(sizeof(ScaleDataType) * a_m_k_scale.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem b_scale_device_buf(sizeof(ScaleDataType) * b_k_n_scale.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + if(do_log > 0) + std::cout << "Upload data to device..." << std::endl; + a_device_buf.ToDevice(a_m_k.mData.data()); + a_scale_device_buf.ToDevice(a_m_k_scale.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + b_scale_device_buf.ToDevice(b_k_n_scale.mData.data()); + + if(do_log > 0) + std::cout << "Done." << std::endl; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMX; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMXGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + a_m_k_scale, + b_k_n, + b_k_n_scale, + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + std::optional best_op_object_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + std::vector kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; // use these when KBatch <= 0 + + if(KBatch > 0) + { + kbatch_list = {KBatch}; + } + + for(std::size_t i = 0; i < kbatch_list.size(); i++) + { + auto kbatch_curr = kbatch_list[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + kbatch_curr, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_log) + { + + if(init_method == 0) + { + auto expected = static_cast(K); + auto computed = type_convert(c_m_n_device_result(0, 12)); + + pass = pass & (std::abs(expected - computed) <= 0.0f); + std::cout << "\nExpected vs Computed: " << expected << " vs " + << computed << ((pass) ? " (PASSED!)" : " (FAILED!)") + << std::endl + << std::endl; + } + else + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "a_scale : ", a_m_k_scale.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b_scale: ", b_k_n_scale.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + } + + std::string op_name = op_ptr->GetTypeString(); + std::optional op_obj_name = op_ptr->GetObjectName(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, + time_kernel, + 0, + n_warmup, + n_iter, + rotating_count > 1, + rotating_count}); + + // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + + // scaling of partial sums(K/ScaleBlockSize)] + // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize + std::size_t flop = + std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + + sizeof(ScaleDataType) * (M * K + K * N) / ScaleBlockSize; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops && ave_time > 1e-10) + { + best_op_name = op_name; + best_op_object_name = op_obj_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch + << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + if(best_op_object_name) + std::cout << best_op_object_name.value() << std::endl; + + return pass; +} + +template +class TestGemmMX : public testing::Test +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using F32 = float; + using ScaleType = e8m0_bexp_t; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using CDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = float; + + public: + static constexpr index_t ScaleBlockSize = std::tuple_element_t<5, Tuple>{}; + static constexpr bool verify_ = true; + static constexpr int init_method_ = 2; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::test::profile_gemm_mx_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt old mode 100644 new mode 100755 index 4aab6323cc..0a68622ebe --- a/test/gemm_universal/CMakeLists.txt +++ b/test/gemm_universal/CMakeLists.txt @@ -1,4 +1,29 @@ -add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp) +add_gtest_executable(test_gemm_universal_wmma_fp16 test_gemm_universal_wmma_fp16.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance) - endif() + target_link_libraries(test_gemm_universal_wmma_fp16 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_wmma_bf16 test_gemm_universal_wmma_bf16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_wmma_bf16 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_wmma_fp8 test_gemm_universal_wmma_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_wmma_fp8 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_xdl_fp16 test_gemm_universal_xdl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_xdl_fp16 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_xdl_fp8 test_gemm_universal_xdl_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_xdl_fp8 PRIVATE utility device_gemm_universal_instance) +endif() + +add_gtest_executable(test_gemm_universal_xdl_bf16 test_gemm_universal_xdl_bf16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_xdl_bf16 PRIVATE utility device_gemm_universal_instance) +endif() diff --git a/test/gemm_universal/test_gemm_universal_ut_cases.inc b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc similarity index 75% rename from test/gemm_universal/test_gemm_universal_ut_cases.inc rename to test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc index 9a21666856..8a6c672a9f 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_bf16.inc @@ -1,6 +1,6 @@ #pragma once -TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -14,7 +14,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -28,7 +28,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -44,7 +44,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, SmallM) } } -TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, SmallM) { std::vector Ms{1, 2, 3, 4, 5, 6}; constexpr int N = 512; @@ -60,7 +60,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, SmallM) } } -TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -74,7 +74,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -88,7 +88,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -104,7 +104,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; constexpr int N = 512; @@ -120,7 +120,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, MidLargeM) } } -TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -134,7 +134,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -148,7 +148,7 @@ TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_KM_KN, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -164,7 +164,7 @@ TYPED_TEST(TestGemmUniversal_KM_KN, PaddK) } } -TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) +TYPED_TEST(TestGemmUniversal_BF16_KM_NK, PaddK) { std::vector Ms{127}; constexpr int N = 512; @@ -180,7 +180,7 @@ TYPED_TEST(TestGemmUniversal_KM_NK, PaddK) } } -TYPED_TEST(TestGemmUniversal_MK_KN, Regular) +TYPED_TEST(TestGemmUniversal_BF16_MK_KN, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -194,7 +194,7 @@ TYPED_TEST(TestGemmUniversal_MK_KN, Regular) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_MK_NK, Regular) +TYPED_TEST(TestGemmUniversal_BF16_MK_NK, Regular) { std::vector Ms{512}; constexpr int N = 512; @@ -207,35 +207,3 @@ TYPED_TEST(TestGemmUniversal_MK_NK, Regular) for(int M : Ms) this->Run(M, N, K, StrideA, StrideB, StrideC); } - -TYPED_TEST(TestGemmUniversal_KM_KN, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - { - int StrideA = M; - this->Run(M, N, K, StrideA, StrideB, StrideC); - } -} - -TYPED_TEST(TestGemmUniversal_KM_NK, Regular) -{ - std::vector Ms{512}; - constexpr int N = 512; - constexpr int K = 512; - - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - { - int StrideA = M; - this->Run(M, N, K, StrideA, StrideB, StrideC); - } -} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc new file mode 100644 index 0000000000..6f6d550625 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc @@ -0,0 +1,113 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP16_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc new file mode 100644 index 0000000000..b831e15e9c --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp8.inc @@ -0,0 +1,113 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_FP8_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp b/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp new file mode 100644 index 0000000000..22376a8599 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_wmma_bf16.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_BF16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_BF16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_BF16_KM_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_BF16_KM_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_KM_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; + +using KernelTypes_KM_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< BF16, BF16, BF16, BF16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); + +#include "test_gemm_universal_ut_cases_bf16.inc" diff --git a/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp b/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp new file mode 100644 index 0000000000..1adee41ed2 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_wmma_fp16.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F16 = ck::half_t; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F16, F16, F16, F16> + >; + +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F16, F16, F16, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp16.inc" diff --git a/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp b/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp new file mode 100644 index 0000000000..3579424496 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_wmma_fp8.cpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +#if CK_USE_WMMA_FP8 + +using F8 = ck::f8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP8_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP8_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F8, F8, F8, BF16> + >; + +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + std::tuple< F8, F8, F8, BF16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp8.inc" + +#endif // CK_USE_WMMA_FP8 diff --git a/test/gemm_universal/test_gemm_universal_xdl.cpp b/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp similarity index 61% rename from test/gemm_universal/test_gemm_universal_xdl.cpp rename to test/gemm_universal/test_gemm_universal_xdl_bf16.cpp index b872d7089a..8fde65657a 100644 --- a/test/gemm_universal/test_gemm_universal_xdl.cpp +++ b/test/gemm_universal/test_gemm_universal_xdl_bf16.cpp @@ -7,8 +7,6 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "test_gemm_universal_util.hpp" -using F8 = ck::f8_t; -using F16 = ck::half_t; using BF16 = ck::bhalf_t; using F32 = float; @@ -29,25 +27,25 @@ struct tuple_concat, std::tuple> } // namespace template -class TestGemmUniversal_MK_KN +class TestGemmUniversal_BF16_MK_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_MK_NK +class TestGemmUniversal_BF16_MK_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_KM_KN +class TestGemmUniversal_BF16_KM_KN : public ck::test::TestGemmUniversal, Tuple>::type> { }; template -class TestGemmUniversal_KM_NK +class TestGemmUniversal_BF16_KM_NK : public ck::test::TestGemmUniversal, Tuple>::type> { }; @@ -55,22 +53,12 @@ class TestGemmUniversal_KM_NK // clang-format off using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< BF16, BF16, BF16, BF16> >; using KernelTypes_MK_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType - std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) - std::tuple< F16, F8, F16, F16>, - std::tuple< F8, F16, F16, F16>, - std::tuple< F8, F8, F8, BF16>, -#endif + std::tuple< BF16, BF16, BF16, BF16> >; @@ -86,9 +74,9 @@ using KernelTypes_KM_KN = ::testing::Types< // clang-format on -TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes_MK_KN); -TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes_MK_NK); -TYPED_TEST_SUITE(TestGemmUniversal_KM_KN, KernelTypes_KM_KN); -TYPED_TEST_SUITE(TestGemmUniversal_KM_NK, KernelTypes_KM_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_MK_NK, KernelTypes_MK_NK); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_KN, KernelTypes_KM_KN); +TYPED_TEST_SUITE(TestGemmUniversal_BF16_KM_NK, KernelTypes_KM_NK); -#include "test_gemm_universal_ut_cases.inc" +#include "test_gemm_universal_ut_cases_bf16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp new file mode 100644 index 0000000000..24f587daf6 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_xdl_fp16.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP16_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_KM_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP16_KM_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + +#endif + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + +#endif + std::tuple< F16, F16, F16, F16> + >; + +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP16_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_universal_ut_cases_fp16.inc" diff --git a/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp new file mode 100644 index 0000000000..e833ab7825 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_xdl_fp8.cpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_FP8_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_FP8_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_KN = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, F16, F16> + >; +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, ComputeDataType, CDataType + +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F16, F8, F16, F16>, + std::tuple< F8, F16, F16, F16>, + std::tuple< F8, F8, F8, BF16>, +#endif + // Fallback test type when FP8 is not enabled + std::tuple< F16, F16, F16, F16> + >; + + +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_KN, KernelTypes_MK_KN); +TYPED_TEST_SUITE(TestGemmUniversal_FP8_MK_NK, KernelTypes_MK_NK); + + +#include "test_gemm_universal_ut_cases_fp8.inc" diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc index b6970c4233..22977866b5 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_bf16.inc @@ -44,34 +44,6 @@ TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_KN, SmallM) } } -TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_KN, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_Streamk_BF16_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - TYPED_TEST(TestGemmUniversal_Streamk_BF16_KM_KN, MidLargeM) { std::vector Ms{127, 255, 312, 799, 1573}; diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc index b2fdfe8193..99c8e6d163 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp16.inc @@ -28,34 +28,6 @@ TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_KN, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - TYPED_TEST(TestGemmUniversal_Streamk_FP16_MK_KN, PaddK) { std::vector Ms{127}; diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc index b3da08f703..b98ee92800 100755 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_ut_cases_fp8.inc @@ -28,34 +28,6 @@ TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_NK, SmallM) this->Run(M, N, K, StrideA, StrideB, StrideC); } -TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_KN, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = N; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - -TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_NK, MidLargeM) -{ - std::vector Ms{127, 255, 312, 799, 1573}; - constexpr int N = 512; - constexpr int K = 320; - - constexpr int StrideA = K; - constexpr int StrideB = K; - constexpr int StrideC = N; - - for(int M : Ms) - this->Run(M, N, K, StrideA, StrideB, StrideC); -} - TYPED_TEST(TestGemmUniversal_Streamk_FP8_MK_KN, PaddK) { std::vector Ms{127}; diff --git a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp index ef3509c0ca..805587a274 100644 --- a/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp +++ b/test/gemm_universal_streamk/test_gemm_universal_streamk_util.hpp @@ -44,9 +44,8 @@ class TestGemmUniversal_Streamk : public testing::Test void SetUp() override { - grid_size_list = {38, 114, 228}; // {38, 76, 114, 152, 190, 228, 266, 304, 342, 380}; - streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile - // Stream-K+ DP, // {0, 1, 2, 3, 4} + streamk_sel_list = {0, 1, 2}; // 0: Data Parallel (DP) mode (Stream-K OFF), 1: 1-tile + // Stream-K+ DP, // {0, 1, 2, 3, 4} // 2:2-tile Stream-K + DP } @@ -58,10 +57,9 @@ class TestGemmUniversal_Streamk : public testing::Test const int StrideC) { for(auto streamk_sel : streamk_sel_list) - for(auto grid_size : grid_size_list) - { - RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, grid_size); - } + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, streamk_sel, -1); + } } void RunSingle(const int M, diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 6d78da8db7..5c816da416 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -2,6 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_da if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) endif() +if(GPU_TARGETS MATCHES "gfx9") + add_executable(test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_bwd_data_xdl_large_cases.cpp) + target_compile_options(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) +endif() add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp) if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp index eb6083c521..c4404b95ba 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl.cpp @@ -21,26 +21,31 @@ class TestGroupedConvndBwdDataXdl : public ::testing::Test using InLayout = std::tuple_element_t<3, Tuple>; std::vector conv_params; + std::vector split_ks{1, 2}; template void Run() { EXPECT_FALSE(conv_params.empty()); bool pass = true; - for(auto& param : conv_params) + for(auto split_k : split_ks) { - pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( - true, // do_verification - 1, // init_method: integer value - false, // do_log - false, // time_kernel - param); + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + split_k); + } } EXPECT_TRUE(pass); } @@ -92,19 +97,16 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D) this->conv_params.clear(); this->conv_params.push_back( - {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + {2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back( - {2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + {2, 2, 2, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back( - {2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + {2, 2, 2, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( - {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - // SplitN case - this->conv_params.push_back( - {2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + {2, 2, 2, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->template Run<2>(); } @@ -112,28 +114,16 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D) { this->conv_params.clear(); this->conv_params.push_back( - {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + {3, 2, 2, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( - {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + {3, 2, 2, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( - {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + {3, 1, 1, 1, 32, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( - {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + {3, 1, 1, 64, 3, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( - {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - // SplitN case - this->conv_params.push_back({3, - 1, - 128, - 4, - 192, - {2, 2, 2}, - {2, 224, 224}, - {1, 224, 224}, - {1, 1, 1}, - {0, 0, 0}, - {0, 0, 0}}); + {3, 1, 1, 1, 1, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->template Run<3>(); } diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp new file mode 100644 index 0000000000..73d793cc5f --- /dev/null +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_large_cases.cpp @@ -0,0 +1,120 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include + +#include "profiler/profile_grouped_conv_bwd_data_impl.hpp" + +template +class TestGroupedConvndBwdDataXdl : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using OutLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using InLayout = std::tuple_element_t<3, Tuple>; + + std::vector conv_params; + std::vector split_ks{1, 2}; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto split_k : split_ks) + { + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + split_k); + } + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndBwdDataXdl2d : public TestGroupedConvndBwdDataXdl +{ +}; + +template +class TestGroupedConvndBwdDataXdl3d : public TestGroupedConvndBwdDataXdl +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D) +{ + this->conv_params.clear(); + // SplitN case + this->conv_params.push_back( + {2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D) +{ + this->conv_params.clear(); + // SplitN case + this->conv_params.push_back({3, + 1, + 128, + 4, + 192, + {2, 2, 2}, + {2, 224, 224}, + {1, 224, 224}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}}); + this->template Run<3>(); +} diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp index 43b77641d1..1cf91df52c 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -77,7 +77,9 @@ using KernelTypes3d = ::testing::Types std::tuple, std::tuple, std::tuple, - std::tuple>; + std::tuple, + std::tuple, + std::tuple>; template class TestGroupedConvndFwd1d : public TestGroupedConvndFwd diff --git a/test/grouped_convnd_fwd_bias_relu/CMakeLists.txt b/test/grouped_convnd_fwd_bias_relu/CMakeLists.txt new file mode 100644 index 0000000000..680a92b19c --- /dev/null +++ b/test/grouped_convnd_fwd_bias_relu/CMakeLists.txt @@ -0,0 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_grouped_convnd_fwd_bias_relu test_grouped_convnd_fwd_bias_relu.cpp) + target_link_libraries(test_grouped_convnd_fwd_bias_relu PRIVATE utility device_grouped_conv2d_fwd_bias_relu_instance device_grouped_conv3d_fwd_bias_relu_instance) +endif() diff --git a/test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp b/test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp new file mode 100644 index 0000000000..c508235d9c --- /dev/null +++ b/test/grouped_convnd_fwd_bias_relu/test_grouped_convnd_fwd_bias_relu.cpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_fwd_bias_relu_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +template +class TestGroupedConvndFwd : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + using IndexType = ck::index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_relu_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = ::testing::Types>; + +using KernelTypes3d = ::testing::Types>; + +template +class TestGroupedConvndFwd2d : public TestGroupedConvndFwd +{ +}; + +template +class TestGroupedConvndFwd3d : public TestGroupedConvndFwd +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwd2d, Test2D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndFwd3d, Test3D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index f65e89bb82..fddb8288a6 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -6,6 +6,8 @@ #include "mx_mfma_op.hpp" using ck::e8m0_bexp_t; +using ck::f4_t; +using ck::f4x2_pk_t; using ck::f8_t; using ck::half_t; using ck::type_convert; @@ -16,7 +18,7 @@ using ck::type_convert; * @param init - selects initialization algorithm for A and B tensors */ template -bool run_mfma_test(ck::index_t init) +bool run_mfma_km_kn_nm_test(ck::index_t init) { using ALayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; @@ -30,7 +32,8 @@ bool run_mfma_test(ck::index_t init) constexpr auto BLOCK_N = mfma_instr.n_per_blk; constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; - const auto mfma_kernel = ck::matmul; + const auto mfma_kernel = ck:: + matmul; bool pass = true; @@ -52,15 +55,72 @@ bool run_mfma_test(ck::index_t init) TEST(MFMA, FP8MFMA16x16x128) { - auto AB_init = 4; - auto pass = run_mfma_test(AB_init); + auto AB_init = 5; + auto pass = run_mfma_km_kn_nm_test(AB_init); EXPECT_TRUE(pass); } TEST(MFMA, FP8MFMA32x32x64) +{ + auto AB_init = 5; + auto pass = run_mfma_km_kn_nm_test(AB_init); + EXPECT_TRUE(pass); +} + +/** + * @brief Run the test for the given MFMA instruction + * + * @param init - selects initialization algorithm for A and B tensors + */ +template +bool run_mfma_mk_kn_mn_test(ck::index_t init) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + using AccType = float; // only MFMA_F32 instructions supported + using CPUAccType = AccType; + + ck::mfma_type(mfma)> mfma_instr; + constexpr auto BLOCK_M = mfma_instr.m_per_blk; + constexpr auto BLOCK_N = mfma_instr.n_per_blk; + constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; + + const auto mfma_kernel = ck:: + matmul; + + bool pass = true; + + pass = ck::mfma_test::TestMFMA{}(mfma_kernel, init); + + return pass; +} + +TEST(MFMA, FP4MFMA16x16x128) { auto AB_init = 4; - auto pass = run_mfma_test(AB_init); + auto pass = run_mfma_mk_kn_mn_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, FP4MFMA32x32x64) +{ + auto AB_init = 4; + auto pass = run_mfma_mk_kn_mn_test( + AB_init); EXPECT_TRUE(pass); } @@ -70,7 +130,7 @@ TEST(MFMA, FP8MFMA32x32x64) * @param init - selects initialization algorithm for A and B tensors */ template -bool run_mxmfma_test(ck::index_t init) +bool run_mxmfma_mk_kn_mn_test(ck::index_t init) { static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 || mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64, @@ -88,8 +148,18 @@ bool run_mxmfma_test(ck::index_t init) constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; constexpr auto BLOCK_X = 32; // scaling vector size - const auto mx_mfma_kernel = - ck::matmul; + const auto mx_mfma_kernel = ck::matmul; bool pass = true; @@ -111,14 +181,34 @@ bool run_mxmfma_test(ck::index_t init) TEST(MXMFMA, MXFP8MFMA16x16x128) { - auto AB_init = 7; - auto pass = run_mxmfma_test(AB_init); + auto AB_init = 5; + auto pass = + run_mxmfma_mk_kn_mn_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP8MFMA32x32x64) { - auto AB_init = 7; - auto pass = run_mxmfma_test(AB_init); + auto AB_init = 5; + auto pass = + run_mxmfma_mk_kn_mn_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP4MFMA16x16x128) +{ + auto AB_init = 4; + auto pass = + run_mxmfma_mk_kn_mn_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP4MFMA32x32x64) +{ + auto AB_init = 4; + auto pass = + run_mxmfma_mk_kn_mn_test( + AB_init); EXPECT_TRUE(pass); } diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 1f9091ebc5..9ce871cfb1 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -30,48 +31,69 @@ enum class MFMA_F8F6F4 }; -template +template struct mfma_type_selector; -template -struct mfma_type_selector +template <> +struct mfma_type_selector<16, 16> { - __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) + template + __device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); - } - - __device__ void operator()(AFragT const& fragA, - const int32_t scale_a, - BFragT const& fragB, - const int32_t scale_b, - AccumFragT& fragAcc) - { - auto op = mfma_type{}; - op.template run<16, 16, AFragT, BFragT, AccumFragT>( - fragA, scale_a, fragB, scale_b, fragAcc); + op.template run<16, 16>(fragA, fragB, fragAcc); } }; -template -struct mfma_type_selector +template <> +struct mfma_type_selector<32, 32> { - __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) + template + __device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); + op.template run<32, 32>(fragA, fragB, fragAcc); } +}; - __device__ void operator()(AFragT const& fragA, - const int32_t scale_a, +template +struct mfma_scale_type_selector; + +template <> +struct mfma_scale_type_selector<16, 16> +{ + template + __device__ static void run(AFragT const& fragA, + AScaleFragT const& scale_a, BFragT const& fragB, - const int32_t scale_b, + BScaleFragT const& scale_b, + AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<16, 16>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc); + } +}; + +template <> +struct mfma_scale_type_selector<32, 32> +{ + template + __device__ static void run(AFragT const& fragA, + AScaleFragT const& scale_a, + BFragT const& fragB, + BScaleFragT const& scale_b, AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<32, 32, AFragT, BFragT, AccumFragT>( - fragA, scale_a, fragB, scale_b, fragAcc); + op.template run<32, 32>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc); } }; @@ -90,7 +112,7 @@ template __device__ AFragT load_A_col_major(AType const* input_ptr) { // clang-format off - // Register Mapping for 16x128: || Register Mapping for 32x64: + // Register Mapping for 16x128 for FP8: || Register Mapping for 32x64 for FP8: // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| @@ -155,13 +177,19 @@ __device__ AFragT load_A_col_major(AType const* input_ptr) auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M); auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M); - using ARawT = typename scalar_type::type; - using AScalarFragT = vector_type::type; + using ARawT = typename scalar_type::type; + using AScalarFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; AScalarFragT fragA{}; + constexpr index_t num_chunks = + (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + #pragma unroll - for(int chunk = 0; chunk < 2; chunk++) + for(int chunk = 0; chunk < num_chunks; chunk++) { #pragma unroll for(uint32_t i = 0; i < chunk_size; i++) @@ -220,6 +248,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | + + // Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4: + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | + // Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] | + // Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] | + // Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] | + // Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] | + // Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] | + // Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] | + // Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] | + // Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] | + // Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] | + // Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] | + // Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] | + // Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] | + // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | + // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | + // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | // clang-format on static constexpr int32_t WAVE_SIZE = 64; @@ -244,23 +294,34 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; // BLOCK_K is a stride in A matrix - auto startOffset = row_major(startCoord2D, BLOCK_K); - // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K); - auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K); + auto startOffset = row_major( + startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K / + // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + auto kMajorOffset = + row_major(majorStepCoord2D, + BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); using ARawT = typename scalar_type::type; using AScalarFragT = vector_type::type; + constexpr index_t num_chunks = + (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + union { AFragT frag; - AScalarFragT chunks[2]; + AScalarFragT chunks[num_chunks]; } fragA{}; - auto* fragPtr = reinterpret_cast(input_ptr + startOffset); - fragA.chunks[0] = *fragPtr; - fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); - fragA.chunks[1] = *fragPtr; + const AScalarFragT* fragPtr; + + for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) + { + fragPtr = reinterpret_cast(input_ptr + startOffset + + chunk_idx * kMajorOffset); + fragA.chunks[chunk_idx] = *fragPtr; + } return fragA.frag; } @@ -318,15 +379,35 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr, // Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) | // Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) | // Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) | + + // Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4: + // Size | BLOCK_M | | BLOCK_M | | BLOCK_M | | BLOCK_M | | || Size | BLOCK_M | | BLOCK_M | | | + // M | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || M | 0 ... 31 | | 0 ... 31 | | Vector | + // Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element| + // Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------| + // Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | K64K65 | x(M,2) | K96K97 | x(M,3) | v[0] || Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | v[0] | + // Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | K66K67 | x(M,2) | K98K99 | x(M,3) | v[1] || Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | v[1] | + // Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | K68K69 | x(M,2) | K100K101 | x(M,3) | v[2] || Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | v[2] | + // Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | K70K71 | x(M,2) | K102K103 | x(M,3) | v[3] || Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | v[3] | + // Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | K72K73 | x(M,2) | K104K105 | x(M,3) | v[4] || Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | v[4] | + // Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | K74K75 | x(M,2) | K106K107 | x(M,3) | v[5] || Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | v[5] | + // Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | K76K77 | x(M,2) | K108K109 | x(M,3) | v[6] || Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | v[6] | + // Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | K78K79 | x(M,2) | K110K111 | x(M,3) | v[7] || Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | v[7] | + // Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | K80K81 | x(M,2) | K112K113 | x(M,3) | v[8] || Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | v[8] | + // Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | K82K83 | x(M,2) | K114K115 | x(M,3) | v[9] || Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | v[9] | + // Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | K84K85 | x(M,2) | K116K117 | x(M,3) | v[10] || Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | v[10] | + // Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | K86K87 | x(M,2) | K118K119 | x(M,3) | v[11] || Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | v[11] | + // Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | K88K89 | x(M,2) | K120K121 | x(M,3) | v[12] || Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | v[12] | + // Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | K90K91 | x(M,2) | K122K123 | x(M,3) | v[13] || Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | v[13] | + // Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | K92K93 | x(M,2) | K124K125 | x(M,3) | v[14] || Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | v[14] | + // Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | K94K95 | x(M,2) | K126K127 | x(M,3) | v[15] || Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | v[15] | // clang-format on - static constexpr uint32_t VW = vectorSize(AFragT{}); - static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); // To start the loading process, let's visualize in 2D coords. // Each thread will load 1 element // We need to know where they start - auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row - (threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col + auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row + (threadIdx.x / BLOCK_M)); // Col // Flatten to 1D row_major offsets. auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; @@ -334,8 +415,7 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr, // BLOCK_K / BLOCK_X is a stride in xA matrix auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X); - // obtain 8-bit exponent - fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + fragX = scale_ptr[startOffset]; return load_A_row_major(input_ptr); } @@ -349,7 +429,7 @@ template __device__ BFragT load_B_col_major(BType const* input_ptr) { // clang-format off - // Register Mapping for 128x16: || Register Mapping for 64x32: + // Register Mapping for 128x16 for FP8: || Register Mapping for 64x32 for FP8: // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| @@ -386,6 +466,28 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | + + // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | + // Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] | + // Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] | + // Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] | + // Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] | + // Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] | + // Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] | + // Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] | + // Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] | + // Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] | + // Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] | + // Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] | + // Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] | + // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | + // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | + // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | // clang-format on static constexpr int32_t WAVE_SIZE = 64; @@ -410,23 +512,34 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col // BLOCK_K is a stride in B matrix - auto startOffset = col_major(startCoord2D, BLOCK_K); - // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K); - auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K); + auto startOffset = col_major( + startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K / + // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + auto kMajorOffset = + col_major(majorStepCoord2D, + BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); using BRawT = typename scalar_type::type; using BScalarFragT = vector_type::type; + constexpr index_t num_chunks = + (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + union { BFragT frag; - BScalarFragT chunks[2]; + BScalarFragT chunks[num_chunks]; } fragB{}; - auto* fragPtr = reinterpret_cast(input_ptr + startOffset); - fragB.chunks[0] = *fragPtr; - fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); - fragB.chunks[1] = *fragPtr; + const BScalarFragT* fragPtr; + + for(index_t chunk = 0; chunk < num_chunks; chunk++) + { + fragPtr = + reinterpret_cast(input_ptr + startOffset + chunk * kMajorOffset); + fragB.chunks[chunk] = *fragPtr; + } return fragB.frag; } @@ -486,15 +599,56 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr, // Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) | // Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) | + // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | + // Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] | + // Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] | + // Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] | + // Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] | + // Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] | + // Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] | + // Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] | + // Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] | + // Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] | + // Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] | + // Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] | + // Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] | + // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | + // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | + // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | + + // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: + // Size | BLOCK_N | | BLOCK_N | | BLOCK_N | | BLOCK_N | | || Size | BLOCK_N | | BLOCK_N | | | + // N | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || N | 0 ... 31 | | 0 ... 31 | | Vector | + // Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element| + // Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------| + // Reg 0 [0:7] | K0K1 | x(0,N) | K32K33 | x(M,1) | K64K65 | x(M,2) | K96K97 | x(M,3) | v[0] || Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | v[0] | + // Reg 0 [8:15] | K2K3 | x(0,N) | K34K35 | x(M,1) | K66K67 | x(M,2) | K98K99 | x(M,3) | v[1] || Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | v[1] | + // Reg 0 [16:23] | K4K5 | x(0,N) | K36K37 | x(M,1) | K68K69 | x(M,2) | K100K101 | x(M,3) | v[2] || Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | v[2] | + // Reg 0 [24:31] | K6K7 | x(0,N) | K38K39 | x(M,1) | K70K71 | x(M,2) | K102K103 | x(M,3) | v[3] || Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | v[3] | + // Reg 1 [0:7] | K8K9 | x(0,N) | K40K41 | x(M,1) | K72K73 | x(M,2) | K104K105 | x(M,3) | v[4] || Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | v[4] | + // Reg 1 [8:15] | K10K11 | x(0,N) | K42K43 | x(M,1) | K74K75 | x(M,2) | K106K107 | x(M,3) | v[5] || Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | v[5] | + // Reg 1 [16:23] | K12K13 | x(0,N) | K44K45 | x(M,1) | K76K77 | x(M,2) | K108K109 | x(M,3) | v[6] || Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | v[6] | + // Reg 1 [24:31] | K14K15 | x(0,N) | K46K47 | x(M,1) | K78K79 | x(M,2) | K110K111 | x(M,3) | v[7] || Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | v[7] | + // Reg 2 [0:7] | K16K17 | x(0,N) | K48K49 | x(M,1) | K80K81 | x(M,2) | K112K113 | x(M,3) | v[8] || Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | v[8] | + // Reg 2 [8:15] | K18K19 | x(0,N) | K50K51 | x(M,1) | K82K83 | x(M,2) | K114K115 | x(M,3) | v[9] || Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | v[9] | + // Reg 2 [16:23] | K20K21 | x(0,N) | K52K53 | x(M,1) | K84K85 | x(M,2) | K116K117 | x(M,3) | v[10] || Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | v[10] | + // Reg 2 [24:31] | K22K23 | x(0,N) | K54K55 | x(M,1) | K86K87 | x(M,2) | K118K119 | x(M,3) | v[11] || Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | v[11] | + // Reg 3 [0:7] | K24K25 | x(0,N) | K56K57 | x(M,1) | K88K89 | x(M,2) | K120K121 | x(M,3) | v[12] || Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | v[12] | + // Reg 3 [8:15] | K26K27 | x(0,N) | K58K59 | x(M,1) | K90K91 | x(M,2) | K122K123 | x(M,3) | v[13] || Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | v[13] | + // Reg 3 [16:23] | K28K29 | x(0,N) | K60K61 | x(M,1) | K92K93 | x(M,2) | K124K125 | x(M,3) | v[14] || Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | v[14] | + // Reg 3 [24:31] | K30K31 | x(0,N) | K62K63 | x(M,1) | K94K95 | x(M,2) | K126K127 | x(M,3) | v[15] || Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | v[15] | // clang-format on - static constexpr uint32_t VW = vectorSize(BFragT{}); - static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); // To start the loading process, let's visualize in 2D coords. // Each thread will load 1 element // We need to know where to start - auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row - threadIdx.x % BLOCK_N); // Col + auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N), // Row + threadIdx.x % BLOCK_N); // Col // Flatten to 1D col_major offsets. auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; @@ -502,7 +656,7 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr, auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X); // obtain 8-bit exponent - fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + fragX = scale_ptr[startOffset]; return load_B_col_major(input_ptr); } @@ -746,15 +900,24 @@ template + int32_t BLOCK_K, + typename ALayout, + typename BLayout, + typename CLayout> __global__ void matmul(const AType* a, const BType* b, CType* c) { constexpr int WAVE_SIZE = 64; assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = vector_type::type; - using BFragT = vector_type::type; + using AFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using BFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; @@ -766,22 +929,42 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) auto fragAcc = AccumFragT{0}; // Load the inputs. - // A = col major, BLOCK_M x BLOCK_K - fragA = load_A_col_major(a); - // B = col major, BLOCK_K x BLOCK_N - fragB = load_B_col_major(b); + if constexpr(is_same_v) + { + fragA = load_A_row_major(a); + } + else + { + fragA = load_A_col_major(a); + } + + if constexpr(is_same_v) + { + printf("This layout is not implemented\n"); + } + else + { + fragB = load_B_col_major(b); + } // Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N - mfma_type_selector{}(fragA, fragB, fragAcc); + using mfma = mfma_type_selector; + mfma::template run<>(fragA, fragB, fragAcc); for(int i = 0; i < vectorSize(fragC); ++i) { fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]); } - auto storeC = store_C_col_major{}; - storeC(c, fragC); + if constexpr(is_same_v) + { + store_C_row_major{}(c, fragC); + } + else + { + store_C_col_major{}(c, fragC); + } } template + int32_t BLOCK_X, + typename ALayout, + typename BLayout, + typename CLayout> __global__ void matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c) { @@ -800,42 +986,73 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = vector_type::type; - using BFragT = vector_type::type; + using AFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using BFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; - using ScaleFragT = int32_t; + using AScaleFragT = vector_type::type; + using BScaleFragT = vector_type::type; // Create frags auto fragA = AFragT{}; auto fragB = BFragT{}; auto fragC = CFragT{}; auto fragAcc = AccumFragT{0}; - auto fragXa = ScaleFragT{0}; - auto fragXb = ScaleFragT{0}; + auto fragXa = AScaleFragT{}; + auto fragXb = BScaleFragT{}; // Load the inputs. - // A = col major, BLOCK_M x BLOCK_K - fragA = load_mx_A_row_major( - a, xa, fragXa); + if constexpr(is_same_v) + { + fragA = + load_mx_A_row_major( + a, xa, fragXa); + } + else + { + printf("This layout is not implemented\n"); + } - // B = col major, BLOCK_K x BLOCK_N - fragB = load_mx_B_col_major( - b, xb, fragXb); + if constexpr(is_same_v) + { + printf("This layout is not implemented\n"); + } + else + { + fragB = + load_mx_B_col_major( + b, xb, fragXb); + } // Scaled Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N - mfma_type_selector{}( - fragA, fragXa, fragB, fragXb, fragAcc); + using mfma = mfma_scale_type_selector; + mfma::template run<>(fragA, + fragXa.template AsType(), + fragB, + fragXb.template AsType(), + fragAcc); for(int i = 0; i < vectorSize(fragC); ++i) { fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]); } - auto storeC = store_C_row_major{}; - storeC(c, fragC); + if constexpr(is_same_v) + { + store_C_row_major{}(c, fragC); + } + else + { + store_C_col_major{}(c, fragC); + } } /** @@ -967,8 +1184,7 @@ struct TestMXMFMA { case 0: a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - a_scales.GenerateTensorValue( - GeneratorTensor_1{ScaleType{0.015625f}}); // 1/64 + a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{0.015625f}}); // 1/6 // NOTE: not all numbers are representable in FP8, BF8, etc. // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32 b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); @@ -986,11 +1202,9 @@ struct TestMXMFMA a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); a_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); b_scales.GenerateTensorValue(GeneratorTensor_2{126, 129}); break; - case 3: // expect small round off errors a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); @@ -1000,6 +1214,14 @@ struct TestMXMFMA b_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); + a_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + b_n_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); + b_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + break; default: // all initial values are representable in FP8, BF8 a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] @@ -1181,6 +1403,11 @@ struct TestMFMA a_m_k.GenerateTensorValue(GeneratorTensor_4(-1, 3)); b_n_k.GenerateTensorValue(GeneratorTensor_4(1, 3)); break; + case 4: + // FP4 values case + a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); + b_n_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); + break; default: // all initial values are representable in FP8, BF8 a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); diff --git a/tile_engine/CMakeLists.txt b/tile_engine/CMakeLists.txt new file mode 100755 index 0000000000..cd1a192a74 --- /dev/null +++ b/tile_engine/CMakeLists.txt @@ -0,0 +1,5 @@ +include_directories(BEFORE + ${CMAKE_CURRENT_LIST_DIR}/include + ) + +add_subdirectory(ops) diff --git a/tile_engine/include/CMakeLists.txt b/tile_engine/include/CMakeLists.txt new file mode 100755 index 0000000000..d11a4b3bee --- /dev/null +++ b/tile_engine/include/CMakeLists.txt @@ -0,0 +1 @@ +message("Add include directory") diff --git a/tile_engine/ops/CMakeLists.txt b/tile_engine/ops/CMakeLists.txt new file mode 100755 index 0000000000..0cf2c16da2 --- /dev/null +++ b/tile_engine/ops/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(gemm) diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt new file mode 100644 index 0000000000..bc613a931e --- /dev/null +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -0,0 +1,51 @@ + + +# generate a list of kernels, but not actually emit files at config stage +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + --list_blobs + RESULT_VARIABLE ret +) +set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS + ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json +) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}") +endif() + +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt GEMM_CODEGEN_BLOBS) + +add_custom_command( + OUTPUT ${GEMM_CODEGEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${CMAKE_CURRENT_BINARY_DIR} + --json ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json + --gen_blobs + DEPENDS ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + ${CMAKE_CURRENT_BINARY_DIR}/gemm_instance_blobs.txt + ${CMAKE_CURRENT_LIST_DIR}/configs/instance_combination.json +) + +set(EXECUTABLE_GEMM_INSTANCE "tile_engine_gemm") +message("adding example ${EXECUTABLE_GEMM_INSTANCE}") + +# use build as include directory +include_directories(${CMAKE_CURRENT_BINARY_DIR}) +add_executable(${EXECUTABLE_GEMM_INSTANCE} EXCLUDE_FROM_ALL gemm_host_api.cpp) +target_include_directories(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${GEMM_CODEGEN_BLOBS}) + +set(EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS) + +list(APPEND EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress) + +target_compile_options(${EXECUTABLE_GEMM_INSTANCE} PRIVATE ${EXECUTABLE_GEMM_INSTANCE_COMPILE_OPTIONS}) + +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) \ No newline at end of file diff --git a/tile_engine/ops/gemm/README.md b/tile_engine/ops/gemm/README.md new file mode 100644 index 0000000000..f7d86e90fe --- /dev/null +++ b/tile_engine/ops/gemm/README.md @@ -0,0 +1,92 @@ +# GEMM Matrix Multiplication + +CK Tile Engine GEMM is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues. + +# Kernel Configurations + +Kernel parameters are specified in the `instance_combination.json` file, including matrix layouts, data types, padding settings, pipelines, schedulers, epilogues, and numerical values for tile and warp sizes. + +Given a valid set of values, tile_engine_gemm will automatically iterate over all possible combinations of BlockTile and WarpTile sizes, as well as the specified pipelines, schedulers, and epilogues from `./configs/instance_combination.json`, and build the corresponding kernels. + + +## Build Instructions +``` bash +# in the root of composable kernel create build directory +mkdir build && cd build +# build composable kernel +sh ../script/cmake-ck-dev.sh ../ # replace with the appropriate architecture (example gfx942) or leave blank +# generate the executable +make tile_engine_gemm -j +``` +`tile_engine_gemm` will be located in the `./bin/` directory. + +_`tile_engine_gemm` must be rebuilt everytime `instance_combination.json` is modified._ +``` bash +rm -rf tile_engine/ && make tile_engine_gemm -j # rebuild +``` + +## tile_engine_gemm inputs +``` + + -m m dimension (default:3840) + -n n dimension (default:4096) + -k k dimension (default:2048) + -stride_a Tensor A stride (default:0) + -stride_b Tensor B stride (default:0) + -stride_c Tensor C stride (default:0) + -split_k SplitK value (default:1) + -v No validation: 0, Validation on CPU: 1, Validation on GPU: 2 (default:2) + -warmup Number of iterations before benchmark the kernel (default:50) + -repeat Number of iterations to benchmark the kernel (default:100) + -timer gpu:gpu timer, cpu:cpu timer (default:gpu) + -init Value for initializing tensor- random: 0, linear: 1, constant(1): 2 (default:0) +-structured_sparsity Sparsity for tensor - 0:false, 1:true (default: 0) + -pipeline possible values are: compv3, compv4, mem (default:compv3) + -scheduler possible values are: intrawave, interwave (default:intrawave) + -epilogue possible values are: cshuffle, default (default:cshuffle) + -pad_m Pad in m direction - true/false (default:false) + -pad_n Pad in n direction - true/false (default:false) + -pad_k Pad in k direction - true/false (default:false) + +Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in instance_combination.json +``` +Note: In `./configs/instance_combination.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above. + +## Example + +The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes. + +```json +{ + /// other parameters /// + + "tile_m": { + "values": [256] + }, + "tile_n": { + "values": [256] + }, + "tile_k": { + "values": [64, 32] + }, + + /// other parameters /// + + "pipeline": { + "values": ["compv3", "compv4", "mem"] + }, + "scheduler": { + "values": ["intrawave", "interwave"] + }, + "epilogue": { + "values": ["default", "cshuffle"] + } +} +``` + +At runtime, a specific subset of the generated kernels can be selected using command-line arguments. +``` bash +./bin/tile_engine_gemm -pipeline=compv3 -scheduler=intrawave -epilogue=default +``` +The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and default epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings. + diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json new file mode 100644 index 0000000000..53197ada6c --- /dev/null +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -0,0 +1,60 @@ +{ + + "layout_a": { + "values": ["r"] + }, + "layout_b": { + "values": ["c"] + }, + "layout_c": { + "values": ["r"] + }, + "datatype": { + "values": ["fp16"] + }, + "tile_m": { + "values": [256] + }, + "tile_n": { + "values": [256] + }, + "tile_k": { + "values": [32] + }, + "warp_m": { + "values": [2] + }, + "warp_n": { + "values": [2] + }, + "warp_k": { + "values": [1] + }, + "warp_tile_m": { + "values": [32] + }, + "warp_tile_n": { + "values": [32] + }, + "warp_tile_k": { + "values": [16] + }, + "kPadM": { + "values": [false] + }, + "kPadN": { + "values": [false] + }, + "kPadK": { + "values": [false] + }, + "pipeline": { + "values": ["compv3", "compv4", "mem"] + }, + "scheduler": { + "values": ["intrawave", "interwave"] + }, + "epilogue": { + "values": ["default", "cshuffle"] + } +} diff --git a/tile_engine/ops/gemm/gemm_host_api.cpp b/tile_engine/ops/gemm/gemm_host_api.cpp new file mode 100755 index 0000000000..a5447cd658 --- /dev/null +++ b/tile_engine/ops/gemm/gemm_host_api.cpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "gemm_common.hpp" +#include "gemm_dispatcher.hpp" +#include "gemm_host_api.hpp" + +void gemm_kernel_launch(ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, + bool structured_sparsity, + KernelTraits& trait, + ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& stream) +{ + return GemmDispatcher::dispatch(c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, + verify, + structured_sparsity, + trait, + args, + stream); +} + +template +void run(const ck_tile::ArgParser& arg_parser) +{ + const ALayout a_layout = ALayout{}; + const BLayout b_layout = BLayout{}; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + int verify = arg_parser.get_int("v"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + bool structured_sparsity = arg_parser.get_bool("structured_sparsity"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(init_method == 2) + { + ck_tile::FillConstant{static_cast(1)}(a_m_k); + ck_tile::FillConstant{static_cast(1)}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + if(structured_sparsity) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + // permute_tensor_b(b_k_n_dev); + permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::GemmHostArgs gemm_args; + gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); + gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); + gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); + gemm_args.k_batch = kbatch; + gemm_args.M = M; + gemm_args.N = N; + gemm_args.K = K; + gemm_args.stride_A = stride_A; + gemm_args.stride_B = stride_B; + gemm_args.stride_C = stride_C; + + KernelTraits trait; + trait.pipeline = arg_parser.get_str("pipeline"); + trait.scheduler = arg_parser.get_str("scheduler"); + trait.epilogue = arg_parser.get_str("epilogue"); + trait.kPadM = arg_parser.get_bool("pad_m"); + trait.kPadN = arg_parser.get_bool("pad_n"); + trait.kPadK = arg_parser.get_bool("pad_k"); + + std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C + << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << std::endl; + + ck_tile::HostTensor c_m_n_host_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + if(verify) + { + gemm_host_reference(verify, + a_m_k, + b_k_n, + c_m_n_host_result, + a_m_k_dev_buf, + b_k_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C); + } + + gemm_kernel_launch(c_m_n_dev_buf, + c_m_n_host_result, + c_m_n_dev_result, + verify, + structured_sparsity, + trait, + gemm_args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + return; +} + +int main(int argc, char* argv[]) +{ + try + { + auto [result, parser] = create_args(argc, argv); + if(!result) + return EXIT_FAILURE; + run(parser); + return 0; + } + catch(const std::exception& e) + { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tile_engine/ops/gemm/gemm_host_api.hpp b/tile_engine/ops/gemm/gemm_host_api.hpp new file mode 100755 index 0000000000..579d2770db --- /dev/null +++ b/tile_engine/ops/gemm/gemm_host_api.hpp @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include +#include "ck_tile/ops/gemm.hpp" + +#pragma once + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +/// @brief Defines the configuration parameters for a GEMM operation, enabling the selection of a +/// specific kernel instance based on the provided settings. +struct KernelTraits +{ + /// @brief The name of the pipeline. + std::string pipeline; + /// @brief The name of the scheduler (e.g., "intrawave", "interwave"). + std::string scheduler; + /// @brief The name of the epilogue (e.g., "cshuffle", "default"). + std::string epilogue; + /// @brief Indicates whether padding is applied to the M dimension. + bool kPadM; + /// @brief Indicates whether padding is applied to the N dimension. + bool kPadN; + /// @brief Indicates whether padding is applied to the K dimension. + bool kPadK; +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +inline auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("split_k", "1", "splitK value") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("structured_sparsity", "0", "0:false, 1:true") + .insert("pipeline", "compv3", "compv3, compv4, mem") + .insert("scheduler", "intrawave", "intrawave, interwave") + .insert("epilogue", "cshuffle", "cshuffle, default") + .insert("pad_m", "false", "true, false") + .insert("pad_n", "false", "true, false") + .insert("pad_k", "false", "true, false"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +void permute_vectors_i4x4_b(Tensor& tensor) +{ + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int8_t input[8]; + + for(int k = 0; k < 4; k++) + { + int8_t i4x2 = tensor(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int8_t hi = input[2]; + int8_t lo = input[0]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 0, i) = i4x2; + } + + { + int8_t hi = input[6]; + int8_t lo = input[4]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 2, i) = i4x2; + } + + { + int8_t hi = input[3]; + int8_t lo = input[1]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 4, i) = i4x2; + } + + { + int8_t hi = input[7]; + int8_t lo = input[5]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 6, i) = i4x2; + } + } + } +} + +/// @brief Function to compare the results of the device and host computations +void compare(ck_tile::index_t K, + ck_tile::index_t kbatch, + ck_tile::HostTensor& c_m_n_dev_result, + ck_tile::HostTensor& c_m_n_host_result) +{ + const float max_accumulated_value = + *std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_result, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; +} + +/// @brief Function to get the kernel output with reference implementation on CPU/GPU +template +void gemm_host_reference(int verify, + ck_tile::HostTensor& a_m_k, + ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C) +{ + if(verify == 1) + { + c_m_n_host_result.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_result); + } + else if(verify == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes()); + c_m_n_host_result.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data()); + } +} diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py new file mode 100755 index 0000000000..3839523e3d --- /dev/null +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -0,0 +1,657 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +from enum import IntEnum +from pathlib import Path +import sys +from typing import List, Optional, Dict, Any +import functools +import itertools +import copy +import json +from dataclasses import dataclass + +DATA_TYPE_MAP = {'fp32' : 'float', + 'fp16' : 'ck_tile::half_t', + 'bf16' : 'ck_tile::bf16_t', + 'int8' : 'ck_tile::int8_t', + 'fp8' : 'ck_tile::fp8_t', + 'bf8' : 'ck_tile::bf8_t', + 'int4' : 'ck_tile::pk_int4_t' + } + +LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', + 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + +DEFAULT_EPILOGUE = """ + using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; +""" + +CSHUFFLE_EPILOGUE = """ + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; +""" +HOT_LOOP_FALSE = """ + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("Num K loop must be larger than number of prefetech stages."); + } +""" +RUN_MEM = """ + // Handle One and Full cases directly + if (tail_num == ck_tile::TailNumber::One) { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } else if (tail_num == ck_tile::TailNumber::Full) { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + // Variadic call using fold expression + auto check_tail = [&](auto... TNs) { + (try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...); + }; + + check_tail( + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{}, + ck_tile::integral_constant{} + ); +""" + +RUN_COMPV3 = """ + if(tail_num == ck_tile::TailNumber::Full) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Odd) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else if(tail_num == ck_tile::TailNumber::Even) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even."); + } +""" + +RUN_COMPV4 = """ + if(tail_num == ck_tile::TailNumber::Three) + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } + else + { + RunSplitk(ck_tile::bool_constant{}, + ck_tile::integral_constant{}); + } +""" + + +PIPELINE_MAP = {'mem' : ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'], + 'compv3' : ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'], + 'compv4' : ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']} + +SCHEDULER_MAP = {'interwave' : 'ck_tile::GemmPipelineScheduler::Interwave', + 'intrawave' : 'ck_tile::GemmPipelineScheduler::Intrawave'} + +EPILOGUE_MAP = {'default' :DEFAULT_EPILOGUE, + 'cshuffle' : CSHUFFLE_EPILOGUE} + +HOT_LOOP_TRUE = {'mem' : RUN_MEM, + 'compv3' : RUN_COMPV3, + 'compv4' : RUN_COMPV4} + + +def BOOL_MAP(b_) -> str: + if b_: + return 'true' + else: + return 'false' + +@dataclass +class GemmConfig: + def __init__(self, config_data): + self.matrix_cfg : Dict[str, Any] = {} + self.impl_cfg : Dict[str, Any] = {} + for key, value in config_data.items(): + if key in ["datatype", "layout_a", "layout_b", "layout_c"]: + self.matrix_cfg[key] = value + else: + self.impl_cfg[key] = value + + @property + def datatype(self) -> str: + return self.matrix_cfg["datatype"]["values"][0] + + @property + def layouts(self) -> List[str]: + return [ + self.matrix_cfg["layout_a"]["values"][0], + self.matrix_cfg["layout_b"]["values"][0], + self.matrix_cfg["layout_c"]["values"][0] + ] + + +class GemmCodeGenerator: + def __init__(self, output_dir: str, config: GemmConfig): + self.output_dir = Path(output_dir) + if not self.output_dir.exists(): + self.output_dir.mkdir() + + self.config = config + self.all_kernels = [] + self.unique_configs = [] + # Validate configurations + self._validate_config() + + def _validate_config(self): + """Validate matrix and implementation configurations""" + # Matrix config validation + for param in ["datatype", "layout_a", "layout_b", "layout_c"]: + if len(self.config.matrix_cfg[param]["values"]) != 1: + raise ValueError(f"Matrix config {param} must have exactly one value") + + # Implementation traits validation + required_params = ["tile_m", "tile_n", "tile_k", "warp_m", "warp_n", "warp_k", + "warp_tile_m", "warp_tile_n", "warp_tile_k", "pipeline", + "epilogue", "scheduler", "kPadM", "kPadN", "kPadK"] + for param in required_params: + if not self.config.impl_cfg.get(param, {}).get("values"): + raise ValueError(f"Missing implementation parameter: {param}") + + def list_all(self): + """List all possible kernel configurations""" + w_p = Path(self.output_dir) + list_p = w_p / 'gemm_instance_blobs.txt' + self._list_config_groups() + with list_p.open('w') as list_f: + list_f.write(str(w_p / ("gemm_common.hpp")) + "\n") + list_f.write(str(w_p / ("gemm_instances.hpp")) + "\n") + list_f.write(str(w_p / ("gemm_dispatcher.hpp")) + "\n") + for group in self.all_kernels: + list_f.write(str(w_p / ("gemm_" + group + ".hpp")) + "\n") + + + + def _list_config_groups(self): + params = [ + ("pipeline", "pipeline"), + ("epilogue", "epilogue"), + ("scheduler", "scheduler"), + ("kPadM", "kPadM"), + ("kPadN", "kPadN"), + ("kPadK", "kPadK") + ] + + # Generate all unique_combinations + _unique = set(itertools.product(*[self.config.impl_cfg[p]["values"] for (p, _) in params])) + for combo in _unique: + config = {name: value for (_, name), value in zip(params, combo)} + pipeline, epilogue, scheduler, kPadM, kPadN, kPadK = config.values() + # To remove some unsupported combinations + unsupported_combination = [("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave")] + if (pipeline, epilogue, scheduler) not in unsupported_combination: + group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}" + self.all_kernels.append(group_name) + self.unique_configs.append(config) + + def generate_all(self): + self._generate_common_header() + self._generate_config_groups() + self._generate_dispatcher() + + + def _generate_common_header(self): + """Generate common header with datatypes and layout""" + self.ctype = self.config.datatype + self.atype = self.config.datatype + self.btype = self.config.datatype + if self.config.datatype in ['fp8', 'bf8']: + self.ctype = 'fp16' + elif self.config.datatype in ['int4']: + self.atype = 'fp16' + self.ctype = 'fp16' + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core.hpp" + +// Data types +using ADataType = {DATA_TYPE_MAP[self.atype]}; +using BDataType = {DATA_TYPE_MAP[self.btype]}; +using AccDataType = float; +using CDataType = {DATA_TYPE_MAP[self.ctype]}; + +// Layout configurations +using ALayout = {LAYOUT_MAP[self.config.layouts[0]]}; +using BLayout = {LAYOUT_MAP[self.config.layouts[1]]}; +using CLayout = {LAYOUT_MAP[self.config.layouts[2]]}; +""" + + + (self.output_dir / "gemm_common.hpp").write_text(content) + + def _generate_config_groups(self): + """Generate implementation configuration groups""" + if not self.unique_configs: # Check if the list is empty + self._list_config_groups() + for config in self.unique_configs: + self._generate_config_group(**config) + self.generate_common_instances_header() + + + def _generate_config_group(self, pipeline: str, epilogue: str, scheduler: str, + kPadM: bool, kPadN: bool, kPadK: bool): + """Generate a configuration group with all tile/warp combinations""" + group_name = f"{pipeline}_{epilogue}_{scheduler}_pad_{BOOL_MAP(kPadM)}_{BOOL_MAP(kPadN)}_{BOOL_MAP(kPadK)}" + filename = f"gemm_{group_name}.hpp" + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_common.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/host.hpp" + +namespace {group_name} {{ +""" + # Add template struct with configuration + content += self._generate_kernel_struct(pipeline, epilogue, scheduler, kPadM, kPadN, kPadK) + + content += f"\n}} // namespace {group_name}\n" + (self.output_dir / filename).write_text(content) + + def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str, + kPadM: bool, kPadN: bool, kPadK: bool) -> str: + """Generate kernel struct template""" + return f""" +template +void try_run(ck_tile::TailNumber tn) {{ + if constexpr (Pipeline::PrefetchStages > static_cast(TN)) {{ + if (tn == TN) {{ + RunSplitk(ck_tile::bool_constant{{}}, + ck_tile::integral_constant{{}}); + }} + }} +}} +template +struct GemmKernel {{ + static constexpr bool kPadM = {BOOL_MAP(kPadM)}; + static constexpr bool kPadN = {BOOL_MAP(kPadN)}; + static constexpr bool kPadK = {BOOL_MAP(kPadK)}; + + static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) {{ + static constexpr bool permuteA = false; + static constexpr bool permuteB = false; + static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; + static constexpr bool TransposeC = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence, + permuteA, + permuteB>; + + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = + ck_tile::TileGemmTraits; + + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}; + + const ck_tile::index_t k_grain = args.k_batch * TileK; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{ + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = {SCHEDULER_MAP[scheduler]}; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = {PIPELINE_MAP[pipeline][1]}; + {EPILOGUE_MAP[epilogue]} + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + {{ + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!"); + }} + + if(s.log_level_ > 0) + {{ + std::cout << "Launching kernel with args:" + << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}" + << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}" + << std::endl; + }} + + ave_time = ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{{}}, grids, blocks, 0, kargs)); + return ave_time; + + }}; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{ + if(args.k_batch == 1) {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} else {{ + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{{}}); + }} + }}; + + if(has_hot_loop) {{ + {HOT_LOOP_TRUE[pipeline]} + }} else {{ + {HOT_LOOP_FALSE} + }} + + return ave_time; + }} + + static std::string get_name() {{ + return std::string("GemmKernel +#include +#include + +struct GemmDispatcher { + static auto& get_kernel_map() { + // Use a static local variable + static std::unordered_map& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>> kernel_map; + return kernel_map; + } + + static void init(bool structured_sparsity) { + auto& kernel_map = get_kernel_map(); + if(!kernel_map.empty()) return; + \n""" + # Add tile/warp instantiations + tile_params = set(itertools.product( + self.config.impl_cfg["tile_m"]["values"], + self.config.impl_cfg["tile_n"]["values"], + self.config.impl_cfg["tile_k"]["values"], + self.config.impl_cfg["warp_m"]["values"], + self.config.impl_cfg["warp_n"]["values"], + self.config.impl_cfg["warp_k"]["values"], + self.config.impl_cfg["warp_tile_m"]["values"], + self.config.impl_cfg["warp_tile_n"]["values"], + self.config.impl_cfg["warp_tile_k"]["values"] + )) + + + for group in self.all_kernels: + content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs& args, + const ck_tile::stream_config& stream) {{ + if(structured_sparsity){{ // SMFMA""" + for tile in tile_params: + # Check if we have valid tile/warp combinations + # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m + if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ + ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): + continue + sparse = self.atype == 'fp16' and \ + ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or + (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) + content += f""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + content += f""" + }} else {{""" + for tile in tile_params: + # Check if we have valid tile/warp combinations + # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m + if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ + ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): + continue + content += f""" + run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + content += f""" + }} + }};\n""" + + content += """ } + + template + static void run_kernel(ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) + { + float avg_time = Kernel::launch(args, stream); + std::string description = Kernel::get_name(); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + std::size_t flop = std::size_t(2) * args.M * args.N * args.K; + std::size_t num_byte = sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.N * args.K + sizeof(CDataType) * args.M * args.N; + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_byte / 1.E6 / avg_time; + + std::cout << "Performance for " << description << " : " << avg_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + if(verify) + compare(args.K, args.k_batch, c_m_n_dev_result, c_m_n_host_result); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + } + + static auto dispatch(ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::HostTensor& c_m_n_host_result, + ck_tile::HostTensor& c_m_n_dev_result, + int verify, bool structured_sparsity, const KernelTraits &trait, ck_tile::GemmHostArgs& gemm_args, + const ck_tile::stream_config& stream) { + init(structured_sparsity); + const std::string key = assemble_key(trait); + auto& kernel_map = get_kernel_map(); + if(auto it = kernel_map.find(key); it != kernel_map.end()) { + return it->second(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, gemm_args, stream); + } + throw std::runtime_error("No suitable kernel found: " + key); + } + +private: + static std::string assemble_key(const KernelTraits &trait) { + return std::string(trait.pipeline) + "_" + + trait.epilogue + "_" + + trait.scheduler + "_" + + "pad_" + + (trait.kPadM ? "true" : "false") + "_" + + (trait.kPadN ? "true" : "false") + "_" + + (trait.kPadK ? "true" : "false"); + } +}; + +""" + (self.output_dir / "gemm_dispatcher.hpp").write_text(content) + + +def do_list_blobs(args, gemm_config): + generator = GemmCodeGenerator(args.working_path, gemm_config) + generator.list_all() + +def do_gen_blobs(args, gemm_config): + generator = GemmCodeGenerator(args.working_path, gemm_config) + generator.generate_all() + + + +def main(args): + # Read json file + with open(args.json, 'r') as json_file: + config_data = json.load(json_file) + + gemm_config = GemmConfig(config_data) + + if args.list_blobs: + do_list_blobs(args, gemm_config) + elif args.gen_blobs: + do_gen_blobs(args, gemm_config) + else: + # If neither was specified, either do nothing or default to gen_blobs + print("No mode specified (use --list_blobs or --gen_blobs). Generating by default...") + do_gen_blobs(args, gemm_config) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen API for CK gemm kernel", + ) + parser.add_argument( + "-w", "--working_path", default="./", required=False, help="the path where all the blobs are going to be generated" + ) + parser.add_argument( + "-j", "--json", required=True, help="Path to the json which contains the kernel configurations" + ) + parser.add_argument( + "-l", "--list_blobs", action = 'store_true', help="List all kernel to file" + ) + parser.add_argument( + "-g", "--gen_blobs", action = 'store_true', help="Generate all kernels into different files" + ) + + args = parser.parse_args() + + main(args)