diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b798e38f3..e5903f3747 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,8 +72,9 @@ message(STATUS "Build with HIP ${HIP_VERSION}") rocm_create_package( - NAME CK-${CK_BACKEND} + NAME composablekernel DESCRIPTION "High Performance Composable Kernel for AMD GPUs" + MAINTAINER "MIOpen Kernels Dev Team " LDCONFIG ) @@ -226,15 +227,12 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin) -configure_file("${PROJECT_SOURCE_DIR}/include/ck/hip_version.hpp.in" "${PROJECT_BINARY_DIR}/include/ck/hip_version.hpp") - include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include ${PROJECT_BINARY_DIR}/include ${PROJECT_SOURCE_DIR}/library/include ) -include(googletest) SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV") if(BUILD_DEV) @@ -243,7 +241,31 @@ if(BUILD_DEV) endif() message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) + add_subdirectory(library) add_subdirectory(example) add_subdirectory(test) add_subdirectory(profiler) + +#Create an interface target for the include only files and call it "composablekernels" +include(CMakePackageConfigHelpers) + +set(version 1.0.0) +write_basic_package_version_file( + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" + VERSION "${version}" + COMPATIBILITY AnyNewerVersion +) + +configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel + NO_CHECK_REQUIRED_COMPONENTS_MACRO +) + +install(FILES + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake" + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) diff --git a/Config.cmake.in b/Config.cmake.in new file mode 100644 index 0000000000..12b5c331ae --- /dev/null +++ b/Config.cmake.in @@ -0,0 +1,11 @@ +@PACKAGE_INIT@ + +set(_composable_kernel_supported_components device_operations host_tensor) + +foreach(_comp ${composable_kernel_FIND_COMPONENTS}) + if(NOT _comp IN_LIST _composable_kernel_supported_components) + set(composable_kernel_FOUND False) + set(composable_kernel_NOT_FOUND_MESSAGE "Unsupported component: ${_comp}") + endif() + include("${CMAKE_CURRENT_LIST_DIR}/composable_kernel${_comp}Targets.cmake") +endforeach() diff --git a/Dockerfile b/Dockerfile index c4cf0fac57..79c961144a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,13 +11,7 @@ ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ RUN apt-get update RUN apt-get install -y wget gnupg RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - -RUN if ! [ -z $OSDB_BKC_VERSION ]; then \ - echo "Using BKC VERISION: $OSDB_BKC_VERSION";\ - sh -c "echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-osdb-deb/ compute-rocm-dkms-no-npi-hipclang ${OSDB_BKC_VERSION} > /etc/apt/sources.list.d/rocm.list" ;\ - cat /etc/apt/sources.list.d/rocm.list;\ - else \ - sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" ;\ - fi +RUN sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list" RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add - RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/apt/sources.list" @@ -25,18 +19,15 @@ RUN sh -c "echo deb https://apt.kitware.com/ubuntu/ bionic main | tee -a /etc/ap # Install dependencies RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ apt-utils \ - sshpass \ build-essential \ cmake-data=3.15.1-0kitware1 \ cmake=3.15.1-0kitware1 \ curl \ - doxygen \ g++ \ gdb \ git \ hip-rocclr \ jq \ - lcov \ libelf-dev \ libncurses5-dev \ libnuma-dev \ @@ -44,7 +35,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- llvm-amdgpu \ pkg-config \ python \ - python3 \ + python3.8 \ python-dev \ python3-dev \ python-pip \ @@ -62,8 +53,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- apt-get clean && \ rm -rf /var/lib/apt/lists/* -# RUN pip3 install --default-timeout=100000 -r requirements.txt - # Setup ubsan environment to printstacktrace RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer ENV UBSAN_OPTIONS=print_stacktrace=1 @@ -83,6 +72,13 @@ ARG PREFIX=/opt/rocm RUN cget install pfultz2/rocm-recipes # Install rbuild RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/6d78a0553babdaea8d2da5de15cbda7e869594b8.tar.gz +# Install packages for processing the performance results +RUN pip3 install --upgrade pip +RUN pip3 install sqlalchemy +RUN pip3 install pymysql +RUN pip3 install pandas +RUN pip3 install setuptools-rust +RUN pip3 install sshtunnel # Setup ubsan environment to printstacktrace ENV UBSAN_OPTIONS=print_stacktrace=1 @@ -92,5 +88,3 @@ ADD rbuild.ini /rbuild.ini ADD dev-requirements.txt dev-requirements.txt RUN rbuild prepare -s develop -d $PREFIX RUN groupadd -f render -# RUN cget install -f min-requirements.txt -# RUN CXXFLAGS='-isystem $PREFIX/include' cget install -f ./mlir-requirements.txt diff --git a/Jenkinsfile b/Jenkinsfile index 824437c970..65876ea1c0 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -7,7 +7,6 @@ def show_node_info() { echo "NODE_NAME = \$NODE_NAME" lsb_release -sd uname -r - cat /sys/module/amdgpu/version ls /opt/ -la """ } @@ -100,35 +99,45 @@ def buildHipClangJob(Map conf=[:]){ def variant = env.STAGE_NAME - def retimage - gitStatusWrapper(credentialsId: '7126e5fe-eb51-4576-b52b-9aaf1de8f0fd', gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { - try { - retimage = docker.build("${image}", dockerArgs + '.') - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES') - { - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + if (params.USE_DOCKERFILE){ + try { + retimage = docker.build("${image}", dockerArgs + '.') + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ + echo "The job was cancelled or aborted" + throw e + } + catch(Exception ex) { + retimage = docker.build("${image}", dockerArgs + "--no-cache .") + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } } } } - catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ - echo "The job was cancelled or aborted" - throw e - } - catch(Exception ex) { - retimage = docker.build("${image}", dockerArgs + "--no-cache .") - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES') - { - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' - } + else{ + timeout(time: 3, unit: 'HOURS'){ + retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull() + image="b56f8ac0d6ea" + sh "docker images" } } withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 5, unit: 'HOURS') { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' cmake_build(conf) } } @@ -140,6 +149,10 @@ def reboot(){ build job: 'reboot-slaves', propagate: false , parameters: [string(name: 'server', value: "${env.NODE_NAME}"),] } + + + + def buildHipClangJobAndReboot(Map conf=[:]){ try{ buildHipClangJob(conf) @@ -156,14 +169,157 @@ def buildHipClangJobAndReboot(Map conf=[:]){ } } + +def runCKProfiler(Map conf=[:]){ + show_node_info() + + env.HSA_ENABLE_SDMA=0 + checkout scm + + def image = "composable_kernels" + def prefixpath = conf.get("prefixpath", "/opt/rocm") + def gpu_arch = conf.get("gpu_arch", "gfx908") + + // Jenkins is complaining about the render group + // def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" + if (conf.get("enforce_xnack_on", false)) { + dockerOpts = dockerOpts + " --env HSA_XNACK=1" + } + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg GPU_ARCH='${gpu_arch}' " + + def variant = env.STAGE_NAME + + def retimage + + gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel') { + if (params.USE_DOCKERFILE){ + try { + retimage = docker.build("${image}", dockerArgs + '.') + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + catch (org.jenkinsci.plugins.workflow.steps.FlowInterruptedException e){ + echo "The job was cancelled or aborted" + throw e + } + catch(Exception ex) { + retimage = docker.build("${image}", dockerArgs + "--no-cache .") + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 5, unit: 'MINUTES') + { + sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo' + } + } + } + } + else{ + timeout(time: 3, unit: 'HOURS'){ + retimage = docker.image('compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:9110_ubuntu18.04_py3.6_pytorch_rocm5.0_internal_testing_7ff5b54').pull() + image="b56f8ac0d6ea" + sh "docker images" + } + } + + withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { + timeout(time: 5, unit: 'HOURS') + { + cmake_build(conf) + dir("script"){ + //run gemm performance tests + def gemm_log = "perf_gemm_${gpu_arch}.log" + sh "rm -f ${gemm_log}" + sh "echo Branch name: ${env.BRANCH_NAME} > ${gemm_log}" + sh "echo Node name: ${NODE_NAME} >> ${gemm_log}" + sh "echo GPU_arch name: ${gpu_arch} >> ${gemm_log}" + sh "rocminfo | grep 'Compute Unit:' >> ${gemm_log} " + sh "hipcc --version | grep -e 'HIP version' >> ${gemm_log}" + sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log}" + sh "./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a ${gemm_log}" + sh "./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a ${gemm_log}" + //results will be parsed, stored, and analyzed within the python script + //the script will return 0 if the performance criteria are met + //or return 1 if the criteria are not met + archiveArtifacts "${gemm_log}" + sh "python3 parse_perf_data.py ${gemm_log} " + //run resnet50 test + def resnet_log = "perf_resnet50_${gpu_arch}.log" + sh "rm -f ${resnet_log}" + sh "echo Branch name: ${env.BRANCH_NAME} > ${resnet_log}" + sh "echo Node name: ${NODE_NAME} >> ${resnet_log}" + sh "echo GPU_arch name: ${gpu_arch} >> ${resnet_log}" + sh "rocminfo | grep 'Compute Unit:' >> ${resnet_log} " + sh "hipcc --version | grep -e 'HIP version' >> ${resnet_log}" + sh "/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log}" + //first run tests with N=256 + sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log}" + //then run with N=4 + sh "./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log}" + archiveArtifacts "${resnet_log}" + //the script will put the results from N=256 and N=4 runs into separate tables + sh "python3 parse_perf_data.py ${resnet_log} " + } + } + } + } + return retimage +} + + +def runPerfTest(Map conf=[:]){ + try{ + runCKProfiler(conf) + } + catch(e){ + echo "throwing error exception in performance tests" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + if (!conf.get("no_reboot", false)) { + reboot() + } + } +} + pipeline { agent none options { parallelsAlwaysFailFast() } - // environment{ - // variable = value - // } + parameters { + booleanParam( + name: "USE_DOCKERFILE", + defaultValue: true, + description: "") + } + environment{ + dbuser = "${dbuser}" + dbpassword = "${dbpassword}" + dbsship = "${dbsship}" + dbsshport = "${dbsshport}" + dbsshuser = "${dbsshuser}" + dbsshpassword = "${dbsshpassword}" + status_wrapper_creds = "${status_wrapper_creds}" + } stages{ stage("Static checks") { parallel{ @@ -178,29 +334,6 @@ pipeline { // buildHipClangJobAndReboot(build_cmd: build_cmd, no_reboot:true, prefixpath: '/opt/rocm', build_type: 'debug') // } // } - stage('Build Profiler: Release, gfx908') - { - agent { label rocmnode("nogpu")} - environment{ - setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ - } - steps{ - buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') - } - } - stage('Build Profiler: Debug, gfx908') - { - agent { label rocmnode("nogpu")} - environment{ - setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ - } - steps{ - // until we stabilize debug build due to compiler crashes - catchError(buildResult: 'SUCCESS', stageResult: 'FAILURE') { - buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Debug') - } - } - } stage('Clang Format') { agent{ label rocmnode("nogpu") } environment{ @@ -220,7 +353,7 @@ pipeline { } } } - stage("Tests") + stage("Tests") { parallel { @@ -228,12 +361,11 @@ pipeline { { agent{ label rocmnode("gfx908")} environment{ - setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ + setup_args = """ -D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ } steps{ - buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release') + buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908") } - } stage("Run Tests: gfx90a") { @@ -242,26 +374,68 @@ pipeline { setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """ } steps{ - buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release') + buildHipClangJobAndReboot(setup_args:setup_args, config_targets: "check", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a") } - } - } } - // enable after the cmake file supports packaging - // stage("Packages") { - // when { - // expression { params.BUILD_PACKAGES && params.TARGET_NOGPU && params.DATATYPE_NA } - // } - // parallel { - // stage("Package /opt/rocm") { - // agent{ label rocmnode("nogpu") } - // steps{ - // buildHipClangJobAndReboot( package_build: "true", prefixpath: '/opt/rocm', gpu_arch: "gfx906;gfx908;gfx90a") - // } - // } - // } - // } + stage("Client App") + { + parallel + { + stage("Run Client App") + { + agent{ label rocmnode("gfx908")} + environment{ + setup_args = """ -D -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " """ + execute_args = """ cd ../test/client_app && rm -rf build && mkdir build && cd build && cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" .. && make """ + } + steps{ + buildHipClangJobAndReboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + } + } + } + } + stage("Performance Tests") + { + parallel + { + stage("Run ckProfiler: gfx908") + { + agent{ label rocmnode("gfx908")} + environment{ + setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx908 -O3 " -DBUILD_DEV=On """ + } + steps{ + runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx908") + } + } + stage("Run ckProfiler: gfx90a") + { + agent{ label rocmnode("gfx90a")} + environment{ + setup_args = """ -D CMAKE_CXX_FLAGS="--offload-arch=gfx90a -O3 " -DBUILD_DEV=On """ + } + steps{ + runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release', gpu_arch: "gfx90a") + } + } + } + } + /* enable after the cmake file supports packaging + stage("Packages") { + when { + expression { params.BUILD_PACKAGES && params.TARGET_NOGPU && params.DATATYPE_NA } + } + parallel { + stage("Package /opt/rocm") { + agent{ label rocmnode("nogpu") } + steps{ + buildHipClangJobAndReboot( package_build: "true", prefixpath: '/opt/rocm', gpu_arch: "gfx906;gfx908;gfx90a") + } + } + } + } + */ } } diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..2fe9a8455e --- /dev/null +++ b/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2018- , Advanced Micro Devices, Inc. (Chao Liu, Jing Zhang) +Copyright (c) 2019- , Advanced Micro Devices, Inc. (Letao Qin, Qianfeng Zhang, Liang Huang, Shaojie Wang) +Copyright (c) 2022- , Advanced Micro Devices, Inc. (Anthony Chang, Chunyu Lai, Illia Silin, Adam Osewski, Poyen Chen, Jehandad Khan) +Copyright (c) 2019-2021, Advanced Micro Devices, Inc. (Hanwen Chang) +Copyright (c) 2019-2020, Advanced Micro Devices, Inc. (Tejash Shah) +Copyright (c) 2020 , Advanced Micro Devices, Inc. (Xiaoyan Zhou) +Copyright (c) 2021-2022, Advanced Micro Devices, Inc. (Jianfeng Yan) + +SPDX-License-Identifier: MIT +Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 4011d34415..f6c933bf5b 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ docker run \ --group-add sudo \ -w /root/workspace \ -v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +rocm/tensorflow:rocm5.1-tf2.6-dev \ /bin/bash ``` @@ -20,7 +20,7 @@ mkdir build && cd build cmake \ -D BUILD_DEV=OFF \ -D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3 \ +-D CMAKE_CXX_FLAGS=" --offload-arch=gfx908 --offload-arch=gfx90a -O3" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH=/opt/rocm \ .. @@ -43,3 +43,13 @@ Instructions for running each individual examples are under ```example/``` make -j ckProfiler ``` Instructions for running ckProfiler are under ```profiler/``` + + +## Caveat +### Kernel Timing and Verification +CK's own kernel timer will warn up kernel once, and then run it multiple times +to get average kernel time. For some kernels that use atomic add, this will cause +output buffer to be accumulated multiple times, causing verfication failure. +To work around it, do not use CK's own timer and do verification at the same time. +CK's own timer and verification in each example and ckProfiler can be enabled or +disabled from command line. diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 9f193b2090..78133af031 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -66,7 +66,7 @@ else() -Wunreachable-code -Wunused - -Wno-sign-compare + -Wsign-compare -Wno-extra-semi-stmt ) if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "Clang") diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index c7e70cc8a9..959bc4f4b0 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -18,6 +18,8 @@ list(APPEND GTEST_CMAKE_CXX_FLAGS -Wno-switch-enum -Wno-zero-as-null-pointer-constant -Wno-unused-member-function + -Wno-comma + -Wno-old-style-cast ) message(STATUS "Suppressing googltest warnings with flags: ${GTEST_CMAKE_CXX_FLAGS}") @@ -33,4 +35,5 @@ FetchContent_MakeAvailable(googletest) target_compile_options(gtest PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) target_compile_options(gtest_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) - +target_compile_options(gmock PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) +target_compile_options(gmock_main PRIVATE ${GTEST_CMAKE_CXX_FLAGS}) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 696d3bac42..c03c454c68 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,3 +1,8 @@ +add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) +add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) +add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) +# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed +add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp new file mode 100644 index 0000000000..9a22628777 --- /dev/null +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -0,0 +1,209 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp new file mode 100644 index 0000000000..32b183a3a1 --- /dev/null +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -0,0 +1,208 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = float; +using BDataType = float; +using CDataType = float; +using AccDataType = float; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp new file mode 100644 index 0000000000..16c9213104 --- /dev/null +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -0,0 +1,206 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = int8_t; +using BDataType = int8_t; +using CDataType = int8_t; +using AccDataType = int32_t; + +using ALayout = Col; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(1); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + bool pass = true; + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + pass = ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + } + + return pass ? 0 : 1; +} diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index a4567dcd6e..b126736be6 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -84,13 +84,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; + ReferenceGemm; int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; // GEMM shape ck::index_t M = 3840; @@ -105,13 +105,13 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } else if(argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); M = std::stoi(argv[4]); N = std::stoi(argv[5]); @@ -125,7 +125,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); exit(0); } @@ -193,12 +193,12 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = @@ -232,7 +232,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); + return ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData) ? 0 : 1; } return 0; diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 9cd40c3976..7e1af4bab2 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -4,7 +4,6 @@ #include #include #include - #include "check_err.hpp" #include "config.hpp" #include "device.hpp" @@ -29,29 +28,30 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // clang-format off #if 1 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; #elif 0 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_ProducerConsumer_CShuffle //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -70,13 +70,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; + ReferenceGemm; int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; // GEMM shape ck::index_t M = 3840; @@ -87,17 +87,21 @@ int main(int argc, char* argv[]) ck::index_t StrideB = 4096; ck::index_t StrideC = 4096; - if(argc == 4) + if(argc == 1) + { + // use default case + } + else if(argc == 4) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } else if(argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); M = std::stoi(argv[4]); N = std::stoi(argv[5]); @@ -111,7 +115,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); exit(0); } @@ -184,12 +188,12 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = @@ -214,7 +218,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; } return 0; diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp new file mode 100644 index 0000000000..7cea68c8b0 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -0,0 +1,238 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F64 = double; + +using ADataType = double; +using BDataType = double; +using CDataType = double; +using AccDataType = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl +//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if 0 + < F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 4, 1, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>; +#else + < F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; +#endif + // clang-format on + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +template +std::ostream& show_2d_matrix(std::ostream& os, Tensor& matrix) +{ + os << "[" << std::endl; + for(int x = 0; x < matrix.mDesc.GetLengths()[0]; x++) + { + os << "["; + for(int y = 0; y < matrix.mDesc.GetLengths()[1]; y++) + { + os << std::setw(4) << static_cast(matrix(x, y)); + } + os << "]" << std::endl; + } + os << "]"; + return os; +} + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "data type: " << typeid(ADataType{}).name() << std::endl; + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + +#if 0 + { + show_2d_matrix(std::cout << "a : ", a_m_k) << std::endl; + show_2d_matrix(std::cout << "b: ", b_k_n) << std::endl; + show_2d_matrix(std::cout << "c_device: ", c_m_n_device_result) << std::endl; + show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl; + } +#endif + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index ab5869db61..27fcd62a2c 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -78,14 +78,19 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; // GEMM shape ck::index_t M = 3840; @@ -100,13 +105,13 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } else if(argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); M = std::stoi(argv[4]); N = std::stoi(argv[5]); @@ -120,7 +125,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); exit(0); } @@ -189,12 +194,12 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = @@ -219,7 +224,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; } return 0; diff --git a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp index 2abebbbac4..1a6e1de4dc 100644 --- a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp +++ b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp @@ -86,9 +86,9 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D1)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, alpha, beta\n"); exit(0); } @@ -216,7 +216,7 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = @@ -246,6 +246,8 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; } + + return 0; } diff --git a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp index f3ed2bad37..f91f6ccfc7 100644 --- a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp +++ b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp @@ -3,89 +3,109 @@ #include #include #include -#include #include "check_err.hpp" #include "config.hpp" -#include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "host_gemm.hpp" #include "device_tensor.hpp" #include "element_wise_operation.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" -#include "reference_gemm_bias_activation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" +#include "device_gemm_multiple_d_xdl_cshuffle.hpp" template using S = ck::Sequence; -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; +using F16 = ck::half_t; +using F32 = float; -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::AddRelu; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl -// clang-format on +// C = A * B +// E = Relu(C + D); +struct AddRelu +{ + __host__ __device__ void + operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d) const + { + const ck::half_t x = c + d; -using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActivation; + e = x > 0 ? x : 0; + } +}; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddRelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; // GEMM shape ck::index_t M = 3840; @@ -94,19 +114,23 @@ int main(int argc, char* argv[]) ck::index_t StrideA = 4096; ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideE = 4096; - if(argc == 4) + if(argc == 1) + { + // use default case + } + else if(argc == 4) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } else if(argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); M = std::stoi(argv[4]); N = std::stoi(argv[5]); @@ -114,14 +138,14 @@ int main(int argc, char* argv[]) StrideA = std::stoi(argv[7]); StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); + StrideE = std::stoi(argv[9]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideE\n"); exit(0); } @@ -141,17 +165,14 @@ int main(int argc, char* argv[]) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - // c0_n[n] - Tensor c0_n(HostTensorDescriptor( - std::vector({static_cast(N)}), std::vector({1}))); + Tensor d_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "c0_n: " << c0_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; switch(init_method) { @@ -159,59 +180,59 @@ int main(int argc, char* argv[]) case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); + DeviceMem d_m_n_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpace()); + DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); - c0_n_device_buf.ToDevice(c0_n.mData.data()); + d_m_n_device_buf.ToDevice(d_m_n.mData.data()); - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; // do GEMM - auto gemm = DeviceGemmInstance{}; + auto device_op = DeviceOpInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - static_cast(c0_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op); + auto invoker = device_op.MakeInvoker(); - if(!gemm.IsSupportedArgument(argument)) + auto argument = + device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(), + b_k_n_device_buf.GetDeviceBuffer(), + std::array{d_m_n_device_buf.GetDeviceBuffer()}, + e_m_n_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{0}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + throw std::runtime_error("wrong! this device_op instance does not support this problem"); } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + - sizeof(CDataType) * M * N + sizeof(CDataType) * N; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(EDataType) * M * N + sizeof(EDataType) * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -220,18 +241,38 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - if(do_verification) { + e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + Tensor c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, c0_n, a_element_op, b_element_op, c_element_op); + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); ref_invoker.Run(ref_argument); - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1; } + + return 0; } diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000..754de47c2b --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) diff --git a/example/04_gemm_add_add_fastgelu/README.md b/example/04_gemm_add_add_fastgelu/README.md new file mode 100644 index 0000000000..08a55fb9a3 --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/README.md @@ -0,0 +1,23 @@ +# Instructions for ```example_gemm_add_add_fastgelu_xdl_fp16``` + +## Run ```example_gemm_add_add_fastgelu_xdl_fp16``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE" +./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +d0_m_n: dim 2, lengths {3840, 4096}, strides {0, 1} +d1_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +e_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 1.26914 ms, 101.525 TFlops, 100.804 GB/s, DeviceGemmMultipleD_Xdl_CShuffle<256, 256, 128, 32, 8, 8> +``` diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp new file mode 100644 index 0000000000..7db5be0c91 --- /dev/null +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_fp16.cpp @@ -0,0 +1,245 @@ +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" +#include "device_gemm_multiple_d_xdl_cshuffle.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using D0DataType = F16; +using D1DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle +//######| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = 4096; + ck::index_t StrideE = 4096; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 12) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD0 = std::stoi(argv[9]); + StrideD1 = std::stoi(argv[10]); + StrideE = std::stoi(argv[11]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, " + "StrideE\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace()); + DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace()); + DeviceMem e_m_n_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + d1_m_n_device_buf.ToDevice(d1_m_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_m_k_device_buf.GetDeviceBuffer(), + b_k_n_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + e_m_n_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(D0DataType) * N + sizeof(D1DataType) * M * N + + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << device_op.GetTypeString() << std::endl; + + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + + e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/04_gemm_bias_relu_add/CMakeLists.txt b/example/04_gemm_bias_relu_add/CMakeLists.txt deleted file mode 100644 index 4f48db94a8..0000000000 --- a/example/04_gemm_bias_relu_add/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_example_executable(example_gemm_xdl_bias_relu_add gemm_xdl_bias_relu_add.cpp) diff --git a/example/04_gemm_bias_relu_add/README.md b/example/04_gemm_bias_relu_add/README.md deleted file mode 100644 index f8d9bd6152..0000000000 --- a/example/04_gemm_bias_relu_add/README.md +++ /dev/null @@ -1,28 +0,0 @@ -# Instructions for ```example_gemm_xdl_bias_relu_add``` - -## Run ```example_gemm_xdl_bias_relu_add``` -```bash -#arg1: verification (0=no, 1=yes) -#arg2: initialization (0=no init, 1=integer value, 2=decimal value) -#arg3: run kernel # of times (>1) -#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC -./bin/example_gemm_xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096 -``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0} -arg.a_grid_desc_k0_m_k1_{512, 3840, 8} -arg.b_grid_desc_k0_n_k1_{512, 4096, 8} -arg.c_grid_desc_m_n_{ 3840, 4096} -arg.c0_grid_desc_m_n_{ 3840, 4096} -arg.c1_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.27583 ms, 100.992 TFlops, 73.9688 GB/s -``` diff --git a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp deleted file mode 100644 index 9405c36881..0000000000 --- a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp +++ /dev/null @@ -1,255 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "element_wise_operation.hpp" -#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" -#include "reference_gemm_bias_activation_add.hpp" - -template -using S = ck::Sequence; - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::AddReluAdd; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl -// clang-format on - -using ReferenceGemmInstance = - ck::tensor_operation::host::ReferenceGemmBiasActivationAdd; -int main(int argc, char* argv[]) -{ - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; - ck::index_t StrideC1 = 4096; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - } - else if(argc == 11) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); - StrideC1 = std::stoi(argv[10]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - // c0_n[n] - Tensor c0_n(HostTensorDescriptor( - std::vector({static_cast(N)}), std::vector({1}))); - - // c1_m_n[m ,n] - Tensor c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "c0_n: " << c0_n.mDesc << std::endl; - std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - c1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - } - - DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); - DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace()); - - a_m_k_device_buf.ToDevice(a_m_k.mData.data()); - b_k_n_device_buf.ToDevice(b_k_n.mData.data()); - c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); - c0_n_device_buf.ToDevice(c0_n.mData.data()); - c1_m_n_device_buf.ToDevice(c1_m_n.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - - // do GEMM - auto gemm = DeviceGemmInstance{}; - - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), - static_cast(b_k_n_device_buf.GetDeviceBuffer()), - static_cast(c_m_n_device_buf.GetDeviceBuffer()), - static_cast(c0_n_device_buf.GetDeviceBuffer()), - static_cast(c1_m_n_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - StrideC1, - a_element_op, - b_element_op, - c_element_op); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, nrepeat); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + - sizeof(CDataType) * M * N + sizeof(CDataType) * N + - sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - - if(do_verification) - { - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_m_k, - b_k_n, - c_m_n_host_result, - c0_n, - c1_m_n, - a_element_op, - b_element_op, - c_element_op); - - ref_invoker.Run(ref_argument); - - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); - } -} diff --git a/example/06_conv2d_fwd_bias_relu/CMakeLists.txt b/example/06_conv2d_fwd_bias_relu/CMakeLists.txt index df8f70606c..4e1dd1f3e6 100644 --- a/example/06_conv2d_fwd_bias_relu/CMakeLists.txt +++ b/example/06_conv2d_fwd_bias_relu/CMakeLists.txt @@ -1,2 +1,2 @@ add_example_executable(example_conv2d_fwd_xdl_bias_relu conv2d_fwd_xdl_bias_relu.cpp) -target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_fwd_util) +target_link_libraries(example_conv2d_fwd_xdl_bias_relu PRIVATE conv_util) diff --git a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp index 751ce16b90..d50afb6854 100644 --- a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp +++ b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp @@ -7,7 +7,7 @@ #include "check_err.hpp" #include "config.hpp" -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "device.hpp" #include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" #include "device_tensor.hpp" @@ -93,7 +93,7 @@ void PrintUseMsg() { std::cout << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" - << "arg3: run kernel # of times (>1)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" << "Following arguments:\n" << " N, K, C, \n" << " , (ie Y, X for 2D)\n" @@ -120,40 +120,40 @@ ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) ck::utils::conv::ConvParams params; int arg_idx = 4; - params.num_dim_spatial = num_dim_spatial; - params.N = std::stoi(argv[arg_idx++]); - params.K = std::stoi(argv[arg_idx++]); - params.C = std::stoi(argv[arg_idx++]); + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); - params.filter_spatial_lengths.resize(num_dim_spatial); + params.filter_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.input_spatial_lengths.resize(num_dim_spatial); + params.input_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_strides.resize(num_dim_spatial); + params.conv_filter_strides_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_dilations.resize(num_dim_spatial); + params.conv_filter_dilations_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); } - params.input_left_pads.resize(num_dim_spatial); + params.input_left_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); } - params.input_right_pads.resize(num_dim_spatial); + params.input_right_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); } return params; @@ -165,9 +165,9 @@ int main(int argc, char* argv[]) { using namespace ck::utils::conv; - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; const int num_dim_spatial = 2; ck::utils::conv::ConvParams params; @@ -176,7 +176,7 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } if(argc >= 5) @@ -184,21 +184,21 @@ int main(int argc, char* argv[]) params = ParseConvParams(argc, argv); } - std::vector input_dims{static_cast(params.N), - static_cast(params.C)}; + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths), - std::end(params.input_spatial_lengths)); + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); - std::vector filter_dims{static_cast(params.K), - static_cast(params.C)}; + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths), - std::end(params.filter_spatial_lengths)); + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N), - static_cast(params.K)}; + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; output_dims.insert(std::end(output_dims), std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); @@ -211,7 +211,7 @@ int main(int argc, char* argv[]) get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); // bias: assume contiguous 1d vector Tensor bias( - HostTensorDescriptor(std::vector({static_cast(params.K)}))); + HostTensorDescriptor(std::vector({static_cast(params.K_)}))); std::cout << "input: " << input.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl; @@ -248,16 +248,16 @@ int main(int argc, char* argv[]) static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), static_cast(bias_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); @@ -269,18 +269,18 @@ int main(int argc, char* argv[]) "not support this problem"); } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = get_flops( - params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); std::size_t num_btype = - get_btype(params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, + get_btype(params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths) + - sizeof(OutDataType) * (params.K); + sizeof(OutDataType) * (params.K_); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -296,16 +296,17 @@ int main(int argc, char* argv[]) weights, host_output, bias, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); ref_invoker.Run(ref_argument); out_device_buf.FromDevice(device_output.mData.data()); - ck::utils::check_err( - host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1; } + + return 0; } diff --git a/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt b/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt index 8bc5980025..b4dd39d83a 100644 --- a/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt +++ b/example/07_conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -1,2 +1,3 @@ -add_example_executable(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp) -target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_fwd_util) +# FIXME: should fix validation failure +add_example_executable_no_testing(example_conv2d_fwd_xdl_bias_relu_add conv2d_fwd_xdl_bias_relu_add.cpp) +target_link_libraries(example_conv2d_fwd_xdl_bias_relu_add PRIVATE conv_util) diff --git a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp index e6339fcd23..1a234ea851 100644 --- a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp +++ b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -7,7 +7,7 @@ #include "check_err.hpp" #include "config.hpp" -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "device.hpp" #include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" #include "device_tensor.hpp" @@ -90,7 +90,7 @@ void PrintUseMsg() { std::cout << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" - << "arg3: run kernel # of times (>1)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" << "Following arguments:\n" << " N, K, C, \n" << " , (ie Y, X for 2D)\n" @@ -117,40 +117,40 @@ ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) ck::utils::conv::ConvParams params; int arg_idx = 4; - params.num_dim_spatial = num_dim_spatial; - params.N = std::stoi(argv[arg_idx++]); - params.K = std::stoi(argv[arg_idx++]); - params.C = std::stoi(argv[arg_idx++]); + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); - params.filter_spatial_lengths.resize(num_dim_spatial); + params.filter_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.input_spatial_lengths.resize(num_dim_spatial); + params.input_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_strides.resize(num_dim_spatial); + params.conv_filter_strides_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_dilations.resize(num_dim_spatial); + params.conv_filter_dilations_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); } - params.input_left_pads.resize(num_dim_spatial); + params.input_left_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); } - params.input_right_pads.resize(num_dim_spatial); + params.input_right_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); } return params; @@ -162,9 +162,9 @@ int main(int argc, char* argv[]) { using namespace ck::utils::conv; - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; const int num_dim_spatial = 2; ck::utils::conv::ConvParams params; @@ -173,7 +173,7 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } if(argc >= 5) @@ -181,21 +181,21 @@ int main(int argc, char* argv[]) params = ParseConvParams(argc, argv); } - std::vector input_dims{static_cast(params.N), - static_cast(params.C)}; + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths), - std::end(params.input_spatial_lengths)); + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); - std::vector filter_dims{static_cast(params.K), - static_cast(params.C)}; + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths), - std::end(params.filter_spatial_lengths)); + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N), - static_cast(params.K)}; + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; output_dims.insert(std::end(output_dims), std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); @@ -209,7 +209,7 @@ int main(int argc, char* argv[]) // bias: assume contiguous 1d vector Tensor bias( - HostTensorDescriptor(std::vector({static_cast(params.K)}))); + HostTensorDescriptor(std::vector({static_cast(params.K_)}))); // residual: assume same layout as output tensor Tensor residual(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); @@ -224,10 +224,10 @@ int main(int argc, char* argv[]) { case 0: break; case 1: - input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - residual.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + input.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + weights.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + bias.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + residual.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -259,16 +259,16 @@ int main(int argc, char* argv[]) static_cast(out_device_buf.GetDeviceBuffer()), static_cast(bias_device_buf.GetDeviceBuffer()), static_cast(resi_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, in_element_op, wei_element_op, out_element_op); @@ -280,20 +280,20 @@ int main(int argc, char* argv[]) "not support this problem"); } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = get_flops( - params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); std::size_t num_btype = - get_btype(params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, + get_btype(params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths) + - sizeof(OutDataType) * (params.K) + + sizeof(OutDataType) * (params.K_) + sizeof(OutDataType) * - (params.N * params.K * output_spatial_lengths[0] * output_spatial_lengths[1]); + (params.N_ * params.K_ * output_spatial_lengths[0] * output_spatial_lengths[1]); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -310,17 +310,18 @@ int main(int argc, char* argv[]) host_output, bias, residual, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, in_element_op, wei_element_op, out_element_op); ref_invoker.Run(ref_argument); out_device_buf.FromDevice(device_output.mData.data()); - ck::utils::check_err( - host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + return ck::utils::check_err(device_output.mData, host_output.mData) ? 0 : 1; } + + return 0; } diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index f602862a04..1724e51f3f 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,6 +1,9 @@ -add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp) -target_link_libraries(example_convnd_fwd_xdl PRIVATE conv_fwd_util) +add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) -target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_fwd_util) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) -target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_fwd_util) +# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed +add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +target_link_libraries(example_convnd_fwd_xdl_fp64 PRIVATE conv_util) +target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util) +target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util) +target_link_libraries(example_convnd_fwd_xdl_fp16 PRIVATE conv_util) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index eaa5683978..d951bc4e4b 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -5,7 +5,7 @@ #include "check_err.hpp" #include "config.hpp" -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "device.hpp" #include "device_tensor.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" @@ -43,10 +43,10 @@ template using DeviceConvNDFwdInstance = ck::tensor_operation::device:: DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< // clang-format off - InDataType, // + InDataType, // WeiDataType, // OutDataType, // - AccDataType, // + AccDataType, // InElementOp, // Input Elementwise Operation WeiElementOp, // Weights Elementwise Operation OutElementOp, // Output Elementwise Operation @@ -110,7 +110,7 @@ void print_use_msg() { std::cout << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" - << "arg3: run kernel # of times (>1)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" << "arg4: N spatial dimensions (default 2)\n" << "Following arguments (depending on number of spatial dims):\n" << " N, K, C, \n" @@ -137,40 +137,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha ck::utils::conv::ConvParams params; int arg_idx = 5; - params.num_dim_spatial = num_dim_spatial; - params.N = std::stoi(argv[arg_idx++]); - params.K = std::stoi(argv[arg_idx++]); - params.C = std::stoi(argv[arg_idx++]); + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); - params.filter_spatial_lengths.resize(num_dim_spatial); + params.filter_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.input_spatial_lengths.resize(num_dim_spatial); + params.input_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_strides.resize(num_dim_spatial); + params.conv_filter_strides_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_dilations.resize(num_dim_spatial); + params.conv_filter_dilations_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); } - params.input_left_pads.resize(num_dim_spatial); + params.input_left_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); } - params.input_right_pads.resize(num_dim_spatial); + params.input_right_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); } return params; @@ -182,9 +182,9 @@ int main(int argc, char* argv[]) { using namespace ck::utils::conv; - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; int num_dim_spatial = 2; ck::utils::conv::ConvParams params; @@ -193,7 +193,7 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); num_dim_spatial = std::stoi(argv[4]); } @@ -202,21 +202,21 @@ int main(int argc, char* argv[]) params = parse_conv_params(num_dim_spatial, argc, argv); } - std::vector input_dims{static_cast(params.N), - static_cast(params.C)}; + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths), - std::end(params.input_spatial_lengths)); + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); - std::vector filter_dims{static_cast(params.K), - static_cast(params.C)}; + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths), - std::end(params.filter_spatial_lengths)); + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N), - static_cast(params.K)}; + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; output_dims.insert(std::end(output_dims), std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); @@ -256,16 +256,16 @@ int main(int argc, char* argv[]) conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); @@ -277,22 +277,22 @@ int main(int argc, char* argv[]) "not support this Conv problem"); } - float ave_time = invoker->Run(argument.get(), nrepeat); + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = get_flops( - params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); std::size_t num_btype = get_btype( - params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv->GetTypeString() << std::endl; if(do_verification) { @@ -302,40 +302,38 @@ int main(int argc, char* argv[]) auto ref_argument = ref_conv.MakeArgument(input, weights, host_output, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); ref_invoker.Run(ref_argument); out_device_buf.FromDevice(device_output.mData.data()); - ck::utils::check_err( - host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + return ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1; }; switch(num_dim_spatial) { case 3: { auto ref_conv = ReferenceConvNDFwdInstance<3>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 2: { auto ref_conv = ReferenceConvNDFwdInstance<2>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 1: { auto ref_conv = ReferenceConvNDFwdInstance<1>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); } } } + return 0; } diff --git a/example/09_convnd_fwd/convnd_fwd_xdl.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp similarity index 79% rename from example/09_convnd_fwd/convnd_fwd_xdl.cpp rename to example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp index e8895b8639..7fa0f0d275 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp @@ -5,7 +5,7 @@ #include "check_err.hpp" #include "config.hpp" -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "device.hpp" #include "device_tensor.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" @@ -39,10 +39,10 @@ template using DeviceConvNDFwdInstance = ck::tensor_operation::device:: DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< // clang-format off - InDataType, // + InDataType, // WeiDataType, // OutDataType, // - AccDataType, // + AccDataType, // InElementOp, // Input Elementwise Operation WeiElementOp, // Weights Elementwise Operation OutElementOp, // Output Elementwise Operation @@ -107,7 +107,7 @@ void print_use_msg() { std::cout << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" - << "arg3: run kernel # of times (>1)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" << "arg4: N spatial dimensions (default 2)\n" << "Following arguments (depending on number of spatial dims):\n" << " N, K, C, \n" @@ -134,40 +134,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha ck::utils::conv::ConvParams params; int arg_idx = 5; - params.num_dim_spatial = num_dim_spatial; - params.N = std::stoi(argv[arg_idx++]); - params.K = std::stoi(argv[arg_idx++]); - params.C = std::stoi(argv[arg_idx++]); + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); - params.filter_spatial_lengths.resize(num_dim_spatial); + params.filter_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.input_spatial_lengths.resize(num_dim_spatial); + params.input_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_strides.resize(num_dim_spatial); + params.conv_filter_strides_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_dilations.resize(num_dim_spatial); + params.conv_filter_dilations_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); } - params.input_left_pads.resize(num_dim_spatial); + params.input_left_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); } - params.input_right_pads.resize(num_dim_spatial); + params.input_right_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); } return params; @@ -179,9 +179,9 @@ int main(int argc, char* argv[]) { using namespace ck::utils::conv; - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; int num_dim_spatial = 2; ck::utils::conv::ConvParams params; @@ -190,7 +190,7 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); num_dim_spatial = std::stoi(argv[4]); } @@ -199,21 +199,21 @@ int main(int argc, char* argv[]) params = parse_conv_params(num_dim_spatial, argc, argv); } - std::vector input_dims{static_cast(params.N), - static_cast(params.C)}; + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths), - std::end(params.input_spatial_lengths)); + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); - std::vector filter_dims{static_cast(params.K), - static_cast(params.C)}; + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths), - std::end(params.filter_spatial_lengths)); + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N), - static_cast(params.K)}; + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; output_dims.insert(std::end(output_dims), std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); @@ -255,16 +255,16 @@ int main(int argc, char* argv[]) conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); @@ -276,16 +276,16 @@ int main(int argc, char* argv[]) "not support this Conv problem"); } - float ave_time = invoker->Run(argument.get(), nrepeat); + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = get_flops( - params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); std::size_t num_btype = - get_btype(params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, + get_btype(params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -301,40 +301,43 @@ int main(int argc, char* argv[]) auto ref_argument = ref_conv.MakeArgument(input, weights, host_output, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); ref_invoker.Run(ref_argument); out_device_buf.FromDevice(device_output.mData.data()); - ck::utils::check_err( - host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + return ck::utils::check_err(device_output.mData, + host_output.mData, + "Error: incorrect results!", + 1e-5f, + 1e-4f) + ? 0 + : 1; }; switch(num_dim_spatial) { case 3: { auto ref_conv = ReferenceConvNDFwdInstance<3>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 2: { auto ref_conv = ReferenceConvNDFwdInstance<2>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 1: { auto ref_conv = ReferenceConvNDFwdInstance<1>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); } } } + return 0; } diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp new file mode 100644 index 0000000000..52440e0d5f --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp64.cpp @@ -0,0 +1,344 @@ +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "conv_util.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +namespace { + +using InDataType = double; +using WeiDataType = double; +using OutDataType = double; +using AccDataType = double; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +using DeviceConvFwdBasePtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 2, // K1 + 16, // MPerXDL + 16, // NPerXDL + 4, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 2, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 2, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +// clang-format on + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 5; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 5; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = 0; + int init_method = 0; + bool time_kernel = false; + int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 5) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + } + + if(argc >= 6) + { + params = parse_conv_params(num_dim_spatial, argc, argv); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_1{1}); + weights.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = + get_btype(params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output]( + const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvNDFwdInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvNDFwdInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvNDFwdInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index 34b4645770..9a1028f88b 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -5,7 +5,7 @@ #include "check_err.hpp" #include "config.hpp" -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "device.hpp" #include "device_tensor.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" @@ -45,10 +45,10 @@ template using DeviceConvNDFwdInstance = ck::tensor_operation::device:: DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< // clang-format off - InDataType, // + InDataType, // WeiDataType, // OutDataType, // - AccDataType, // + AccDataType, // InElementOp, // Input Elementwise Operation WeiElementOp, // Weights Elementwise Operation OutElementOp, // Output Elementwise Operation @@ -112,7 +112,7 @@ void print_use_msg() { std::cout << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" - << "arg3: run kernel # of times (>1)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" << "arg4: N spatial dimensions (default 2)\n" << "Following arguments (depending on number of spatial dims):\n" << " N, K, C, \n" @@ -139,40 +139,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, cha ck::utils::conv::ConvParams params; int arg_idx = 5; - params.num_dim_spatial = num_dim_spatial; - params.N = std::stoi(argv[arg_idx++]); - params.K = std::stoi(argv[arg_idx++]); - params.C = std::stoi(argv[arg_idx++]); + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); - params.filter_spatial_lengths.resize(num_dim_spatial); + params.filter_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.input_spatial_lengths.resize(num_dim_spatial); + params.input_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_strides.resize(num_dim_spatial); + params.conv_filter_strides_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_dilations.resize(num_dim_spatial); + params.conv_filter_dilations_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); } - params.input_left_pads.resize(num_dim_spatial); + params.input_left_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); } - params.input_right_pads.resize(num_dim_spatial); + params.input_right_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); } return params; @@ -184,9 +184,9 @@ int main(int argc, char* argv[]) { using namespace ck::utils::conv; - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; int num_dim_spatial = 2; ck::utils::conv::ConvParams params; @@ -195,7 +195,7 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); num_dim_spatial = std::stoi(argv[4]); } @@ -204,21 +204,21 @@ int main(int argc, char* argv[]) params = parse_conv_params(num_dim_spatial, argc, argv); } - std::vector input_dims{static_cast(params.N), - static_cast(params.C)}; + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths), - std::end(params.input_spatial_lengths)); + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); - std::vector filter_dims{static_cast(params.K), - static_cast(params.C)}; + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths), - std::end(params.filter_spatial_lengths)); + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N), - static_cast(params.K)}; + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; output_dims.insert(std::end(output_dims), std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); @@ -258,16 +258,16 @@ int main(int argc, char* argv[]) conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); @@ -279,16 +279,16 @@ int main(int argc, char* argv[]) "not support this Conv problem"); } - float ave_time = invoker->Run(argument.get(), nrepeat); + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = get_flops( - params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); std::size_t num_btype = get_btype( - params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -304,40 +304,38 @@ int main(int argc, char* argv[]) auto ref_argument = ref_conv.MakeArgument(input, weights, host_output, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); ref_invoker.Run(ref_argument); out_device_buf.FromDevice(device_output.mData.data()); - ck::utils::check_err( - host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + return ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f) ? 0 : 1; }; switch(num_dim_spatial) { case 3: { auto ref_conv = ReferenceConvNDFwdInstance<3>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 2: { auto ref_conv = ReferenceConvNDFwdInstance<2>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 1: { auto ref_conv = ReferenceConvNDFwdInstance<1>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); } } } + return 0; } diff --git a/example/10_conv2d_bwd_data/CMakeLists.txt b/example/10_conv2d_bwd_data/CMakeLists.txt index f300bc9645..17aca1481b 100644 --- a/example/10_conv2d_bwd_data/CMakeLists.txt +++ b/example/10_conv2d_bwd_data/CMakeLists.txt @@ -1,2 +1,2 @@ add_example_executable(example_conv2d_bwd_data_xdl conv2d_bwd_data_xdl.cpp) -target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_fwd_util) +target_link_libraries(example_conv2d_bwd_data_xdl PRIVATE conv_util) diff --git a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp index f3f9b497f5..2d25f5ac2f 100644 --- a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp +++ b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp @@ -77,9 +77,9 @@ using ReferenceConvBwdInstance = ck::tensor_operation::host::ReferenceConvBwdDat int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; // Conv shape ck::index_t N = 128; @@ -102,13 +102,13 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } else if(argc == 19) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); N = std::stoi(argv[4]); K = std::stoi(argv[5]); @@ -130,7 +130,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); exit(0); @@ -214,7 +214,7 @@ int main(int argc, char* argv[]) "not support this Conv problem"); } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; @@ -249,6 +249,10 @@ int main(int argc, char* argv[]) in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); - ck::utils::check_err(in_n_c_hi_wi_device_result.mData, in_n_c_hi_wi_host_result.mData); + return ck::utils::check_err(in_n_c_hi_wi_device_result.mData, + in_n_c_hi_wi_host_result.mData) + ? 0 + : 1; } + return 0; } diff --git a/example/11_conv2d_bwd_weight/CMakeLists.txt b/example/11_conv2d_bwd_weight/CMakeLists.txt index ff001eab72..3d771b5569 100644 --- a/example/11_conv2d_bwd_weight/CMakeLists.txt +++ b/example/11_conv2d_bwd_weight/CMakeLists.txt @@ -1,2 +1,2 @@ add_example_executable(example_conv2d_bwd_weight_xdl conv2d_bwd_weight_xdl.cpp) -target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_fwd_util) +target_link_libraries(example_conv2d_bwd_weight_xdl PRIVATE conv_util) diff --git a/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp index bf78cc87e0..1578161116 100644 --- a/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp +++ b/example/11_conv2d_bwd_weight/conv2d_bwd_weight_xdl.cpp @@ -82,9 +82,9 @@ using ReferenceConvBwdWeightInstance = int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; int do_log = 0; int split_k = 4; @@ -109,7 +109,7 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); do_log = std::stoi(argv[4]); split_k = std::stoi(argv[5]); } @@ -117,7 +117,7 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); do_log = std::stoi(argv[4]); split_k = std::stoi(argv[5]); @@ -141,7 +141,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: is show log (0=no, 1=yes)\n"); printf("arg5: split-k \n"); printf("arg6 to 19: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " @@ -246,7 +246,7 @@ int main(int argc, char* argv[]) return 1; } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; @@ -291,6 +291,9 @@ int main(int argc, char* argv[]) LogRangeAsType(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") << std::endl; } - ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData); + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData) + ? 0 + : 1; } + return 0; } diff --git a/example/12_reduce/CMakeLists.txt b/example/12_reduce/CMakeLists.txt index 734c1955d6..9045a78a85 100644 --- a/example/12_reduce/CMakeLists.txt +++ b/example/12_reduce/CMakeLists.txt @@ -1 +1,2 @@ add_example_executable(example_reduce_blockwise reduce_blockwise.cpp) +add_example_executable(example_reduce_blockwise_two_call reduce_blockwise_two_call.cpp) diff --git a/example/12_reduce/README.md b/example/12_reduce/README.md index 6fd3b3dcf3..826d2f6c33 100644 --- a/example/12_reduce/README.md +++ b/example/12_reduce/README.md @@ -5,23 +5,37 @@ # -D : input 4-d tensor lengths # -v : verification (0=no, 1=yes) #arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) -#arg2: run kernel # of times (>1) -./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 10 +#arg2: time kernel (0=no, 1=yes) +./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1 ``` Result ``` -launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 3 times... -Perf: 0.23536 ms, 267.32 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1> -error: 0 -max_diff: 0, 529, 529 -root@dc-smc-18:/data/composable_kernel/Build3# bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 10 -launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} -Warm up +./bin/example_reduce_blockwise -D 16,64,32,960 -v 1 1 1 +launch_and_time_kernel: grid_dim {240, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time Start running 10 times... -Perf: 0.23392 ms, 268.966 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1> -error: 0 -max_diff: 0, 528, 528 +Perf: 0.282592 ms, 222.641 GB/s, DeviceReduceBlockWise<256,M_C4_S1,K_C64_S1,InSrcVectorDim_0_InSrcVectorSize_1_OutDstVectorSize_1> +``` + +# Instructions for ```example_reduce_blockwise_two_call``` + +## Run ```example_reduce_blockwise_two_call``` +```bash +#arg1: verification (0=no, 1=yes( +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +./bin/example_reduce_blockwise_two_call 1 2 1 +``` + +Result +``` +./bin/example_reduce_blockwise_two_call 1 2 1 +launch_and_time_kernel: grid_dim {204800, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +launch_and_time_kernel: grid_dim {6400, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 2.1791 ms, 771.42 GB/s, DeviceReduceBlockWise<256,M_C32_S1,K_C8_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> => DeviceReduceBlockWise<256,M_C256_S1,K_C1_S1,InSrcVectorDim_1_InSrcVectorSize_1_OutDstVectorSize_1> ``` diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index 293b593902..66e9762314 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -12,8 +12,8 @@ #include "host_tensor_generator.hpp" #include "device_tensor.hpp" #include "device_base.hpp" -#include "device_reduce_blockwise.hpp" -#include "host_reduce_util.hpp" +#include "device_reduce_multiblock.hpp" +#include "host_common_util.hpp" #include "host_reduction.hpp" #include "reduction_enums.hpp" @@ -30,96 +30,53 @@ constexpr int Rank = 4; constexpr int NumReduceDim = 3; constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; -constexpr NanPropagation NanOpt = NanPropagation::PROPAGATE_NAN; -constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true; -constexpr ReduceTensorIndices IndicesOpt = ReduceTensorIndices::NO_INDICES; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; -using ReduceOperation = typename reduce_binary_operator::opType; +using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator::AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; -using DeviceReduceInstance = DeviceReduceBlockWise; +using DeviceReduceInstance = DeviceReduceMultiBlock; static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, - {"scales", required_argument, nullptr, 'S'}, {"verify", required_argument, nullptr, 'v'}, {"help", no_argument, nullptr, '?'}, {nullptr, 0, nullptr, 0}}; class SimpleAppArgs { - template - static T getSingleValueFromString(const std::string& valueStr) - { - std::istringstream iss(valueStr); - - T ret; - - iss >> ret; - - return (ret); - }; - - template - static std::vector getTypeValuesFromString(const char* cstr_values) - { - std::string valuesStr(cstr_values); - - std::vector values; - std::size_t pos = 0; - std::size_t new_pos; - - new_pos = valuesStr.find(',', pos); - while(new_pos != std::string::npos) - { - const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); - - T val = getSingleValueFromString(sliceStr); - - values.push_back(val); - - pos = new_pos + 1; - new_pos = valuesStr.find(',', pos); - }; - - std::string sliceStr = valuesStr.substr(pos); - T val = getSingleValueFromString(sliceStr); - - values.push_back(val); - - return (values); - }; - private: int option_index = 0; public: - std::vector inLengths; - std::vector scales; + std::vector inLengths = {16, 64, 32, 960}; + std::vector scales = {1.0f, 0.0f}; - bool do_verification = false; - - int init_method = 1; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; public: void show_usage(const char* cmd) @@ -127,24 +84,24 @@ class SimpleAppArgs std::cout << "Usage of " << cmd << std::endl; std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" << std::endl; - std::cout << "--scales or -S, comma separated two float values for alpha and beta" - << std::endl; std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " "comparing with the host-based reduction" << std::endl; std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer " "value, 3=decimal value)" << std::endl; - std::cout << "Arg2 -- number of repeats to run the kernel" << std::endl; + std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl; }; int processArgs(int argc, char* argv[]) { - unsigned int ch; + using ck::host_common::getTypeValuesFromString; + + int ch; while(1) { - ch = getopt_long(argc, argv, "D:S:v:l:", long_options, &option_index); + ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index); if(ch == -1) break; switch(ch) @@ -155,12 +112,6 @@ class SimpleAppArgs inLengths = getTypeValuesFromString(optarg); break; - case 'S': - if(!optarg) - throw std::runtime_error("Invalid option format!"); - - scales = getTypeValuesFromString(optarg); - break; case 'v': if(!optarg) throw std::runtime_error("Invalid option format!"); @@ -182,7 +133,7 @@ class SimpleAppArgs throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); init_method = std::atoi(argv[optind++]); - nrepeat = std::atoi(argv[optind]); + time_kernel = static_cast(std::atoi(argv[optind])); if(scales.empty()) { @@ -196,23 +147,21 @@ class SimpleAppArgs int main(int argc, char* argv[]) { - using namespace ck::host_reduce; - const std::vector reduceDims{0, 1, 2}; const std::vector invariantDims{3}; SimpleAppArgs args; - if(args.processArgs(argc, argv) < 0) - return (-1); + if(argc > 1) + { + if(args.processArgs(argc, argv) < 0) + return (-1); + }; constexpr bool op_support_indices = (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX); - constexpr bool NeedIndices = - (op_support_indices && (IndicesOpt != ReduceTensorIndices::NO_INDICES)); - // if input is half type, no reason to use float for indiced reduction operation and must use // float for non-indiced reduction operation for accuracy constexpr bool invalid_reduce_1 = @@ -226,8 +175,7 @@ int main(int argc, char* argv[]) (op_support_indices && !std::is_same::value); // indices option can only be used when it is really needed - constexpr bool invalid_reduce_3 = - (!op_support_indices && IndicesOpt != ReduceTensorIndices::NO_INDICES); + constexpr bool invalid_reduce_3 = (!op_support_indices && OutputIndex); constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3); @@ -295,51 +243,65 @@ int main(int argc, char* argv[]) if(beta != 0.0f) out_dev.ToDevice(out.mData.data()); - size_t indicesSizeInBytes = NeedIndices ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0; + size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int32_t) : 0; - DeviceMem out_indices_dev(indicesSizeInBytes); + DeviceMem out_index_dev(indicesSizeInBytes); + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); if(args.do_verification) { ReductionHost + OutputIndex> hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); - hostReduce.Run( - alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); + hostReduce.Run(alpha, + in.mData.data(), + beta, + out_ref.mData.data(), + out_indices_ref.mData.data(), + in_elementwise_op, + acc_elementwise_op); }; - const auto i_inLengths = to_int_vector(args.inLengths); - const auto i_inStrides = to_int_vector(inStrides); - const auto i_outLengths = to_int_vector(outLengths); - const auto i_outStrides = to_int_vector(outStrides); + std::vector i_inLengths; + std::vector i_inStrides; + std::vector i_outLengths; + std::vector i_outStrides; + + i_inLengths.assign(args.inLengths.begin(), args.inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + i_outLengths.assign(outLengths.begin(), outLengths.end()); + i_outStrides.assign(outStrides.begin(), outStrides.end()); auto reduce = DeviceReduceInstance{}; - auto wsSizeInBytes = reduce.GetWorkspaceSizeInBytes(i_inLengths, reduceDims); - - DeviceMem ws_dev(wsSizeInBytes); - - auto argument_ptr = - reduce.MakeArgumentPointer(i_inLengths, - i_inStrides, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - in_dev.GetDeviceBuffer(), - out_dev.GetDeviceBuffer(), - out_indices_dev.GetDeviceBuffer(), - ws_dev.GetDeviceBuffer(), - InElementwiseOperation{static_cast(reduce_total_length)}, - AccElementwiseOperation{static_cast(reduce_total_length)}); + auto argument_ptr = reduce.MakeArgumentPointer(i_inLengths, + i_inStrides, + i_outLengths, + i_outStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + out_index_dev.GetDeviceBuffer(), + in_elementwise_op, + acc_elementwise_op); if(!reduce.IsSupportedArgument(argument_ptr.get())) { @@ -352,7 +314,7 @@ int main(int argc, char* argv[]) auto invoker_ptr = reduce.MakeInvokerPointer(); - float avg_time = invoker_ptr->Run(argument_ptr.get(), args.nrepeat); + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel}); std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) + invariant_total_length * sizeof(OutDataType); @@ -362,16 +324,19 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name << std::endl; + bool pass = true; + if(args.do_verification) { out_dev.FromDevice(out.mData.data()); - ck::utils::check_err(out.mData, out_ref.mData); + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); - if(NeedIndices) + if(OutputIndex) { - out_indices_dev.FromDevice(out_indices.mData.data()); - ck::utils::check_err(out_indices.mData, out_indices_ref.mData); - ; + out_index_dev.FromDevice(out_indices.mData.data()); + pass = pass && ck::utils::check_err(out_indices.mData, out_indices_ref.mData); }; }; + + return (pass ? 0 : 1); } diff --git a/example/12_reduce/reduce_blockwise_two_call.cpp b/example/12_reduce/reduce_blockwise_two_call.cpp new file mode 100644 index 0000000000..e4823667a8 --- /dev/null +++ b/example/12_reduce/reduce_blockwise_two_call.cpp @@ -0,0 +1,301 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_reduce_multiblock.hpp" +#include "host_common_util.hpp" +#include "host_reduction.hpp" + +#include "reduction_enums.hpp" +#include "reduction_operator_mapping.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +using InOutDataType = ck::half_t; +using InOutDataType = ck::half_t; +using AccDataType = float; + +constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::NORM2; +constexpr bool PropagateNan = true; +constexpr bool OutputIndex = false; + +using ReduceOperation = typename reduce_binary_operator::opType; +using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; +using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + +using PassThroughOp = tensor_operation::element_wise::PassThrough; + +using DeviceReduceInstance_1 = DeviceReduceMultiBlock; + +using DeviceReduceInstance_2 = DeviceReduceMultiBlock; + +static bool do_verify; +static int init_method; +static float alpha; +static float beta; +static bool time_kernel; + +int main(int argc, char* argv[]) +{ + // used by the device reduction + const std::vector reduceDims_1 = {4}; + const std::vector invariantDims_1 = {0, 1, 2, 3}; + + const std::vector reduceDims_2 = {3}; + const std::vector invariantDims_2 = {0, 1, 2}; + + // used by the host reduction + const std::vector reduceDims = {3, 4}; + const std::vector invariantDims = {0, 1, 2}; + + const std::vector inLengths_1 = {64, 320, 80, 4, 128}; + + // input lengths of the second reduction, which is also the output lengths of the first + // reduction + const std::vector inLengths_2 = {64, 320, 80, 4}; + + const std::vector outLengths = {64, 320, 80}; + + if(argc == 1) + { + do_verify = true; + init_method = 2; + time_kernel = true; + } + else if(argc == 4) + { + do_verify = static_cast(argv[1]); + init_method = atoi(argv[2]); + time_kernel = static_cast(atoi(argv[3])); + } + else + { + std::ostringstream ostr; + + ostr << "Wrong parameter! " << std::endl + << "Usage: " << argv[0] << "[verify 0/1] init_method time_kernel" << std::endl; + + throw std::runtime_error(ostr.str()); + }; + + alpha = 1.0f; + beta = 0.0f; + + Tensor in_1(inLengths_1); + + Tensor out_ref(outLengths); + Tensor in_2(inLengths_2); // also the output tensor of the first reduction + Tensor out(outLengths); + + auto inStrides_1 = in_1.mDesc.GetStrides(); + auto inStrides_2 = in_2.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + size_t invariant_total_length = out.mDesc.GetElementSize(); + size_t reduce_total_length = in_1.mDesc.GetElementSize() / invariant_total_length; + + std::size_t num_thread = 1; + + if(do_verify) + { + switch(init_method) + { + case 0: break; + case 1: + in_1.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in_1.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in_1.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, + num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) + out.mData[i] = out_ref.mData[i]; + }; + + DeviceMem in_1_dev(sizeof(InOutDataType) * in_1.mDesc.GetElementSpace()); + DeviceMem in_2_dev(sizeof(InOutDataType) * in_2.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(InOutDataType) * out.mDesc.GetElementSpace()); + + in_1_dev.ToDevice(in_1.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); + + if(do_verify) + { + ReductionHost + hostReduce(in_1.mDesc, out_ref.mDesc, invariantDims, reduceDims); + + hostReduce.Run(alpha, + in_1.mData.data(), + beta, + out_ref.mData.data(), + nullptr, + in_elementwise_op, + acc_elementwise_op); + }; + + std::vector i_inLengths_1; + std::vector i_inStrides_1; + std::vector i_inLengths_2; + std::vector i_inStrides_2; + std::vector i_outLengths; + std::vector i_outStrides; + + i_inLengths_1.assign(inLengths_1.begin(), inLengths_1.end()); + i_inStrides_1.assign(inStrides_1.begin(), inStrides_1.end()); + i_inLengths_2.assign(inLengths_2.begin(), inLengths_2.end()); + i_inStrides_2.assign(inStrides_2.begin(), inStrides_2.end()); + i_outLengths.assign(outLengths.begin(), outLengths.end()); + i_outStrides.assign(outStrides.begin(), outStrides.end()); + + auto reduce_1 = DeviceReduceInstance_1{}; + + auto argument_ptr_1 = reduce_1.MakeArgumentPointer(i_inLengths_1, + i_inStrides_1, + i_inLengths_2, + i_inStrides_2, + reduceDims_1, + 1.0f, + 0.0f, + in_1_dev.GetDeviceBuffer(), + nullptr, + in_2_dev.GetDeviceBuffer(), + nullptr, + in_elementwise_op, + PassThroughOp{}); + + if(!reduce_1.IsSupportedArgument(argument_ptr_1.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + }; + + auto invoker_ptr_1 = reduce_1.MakeInvokerPointer(); + + auto reduce_2 = DeviceReduceInstance_2{}; + + auto argument_ptr_2 = reduce_2.MakeArgumentPointer(i_inLengths_2, + i_inStrides_2, + i_outLengths, + i_outStrides, + reduceDims_2, + alpha, + beta, + in_2_dev.GetDeviceBuffer(), + nullptr, + out_dev.GetDeviceBuffer(), + nullptr, + PassThroughOp{}, + acc_elementwise_op); + + if(!reduce_2.IsSupportedArgument(argument_ptr_2.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + }; + + auto invoker_ptr_2 = reduce_2.MakeInvokerPointer(); + + float avg_time_1 = invoker_ptr_1->Run(argument_ptr_1.get(), StreamConfig{nullptr, time_kernel}); + float avg_time_2 = invoker_ptr_2->Run(argument_ptr_2.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InOutDataType) + + invariant_total_length * sizeof(InOutDataType); + + float gb_per_sec = num_bytes / 1.E6 / (avg_time_1 + avg_time_2); + + std::cout << "Perf: " << avg_time_1 + avg_time_2 << " ms, " << gb_per_sec << " GB/s, " + << reduce_1.GetTypeString() << " => " << reduce_2.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verify) + { + out_dev.FromDevice(out.mData.data()); + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); + }; + + return (pass ? 0 : 1); +} diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt index 1fdeb4c585..db09c03321 100644 --- a/example/13_pool2d_fwd/CMakeLists.txt +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -1 +1,3 @@ -add_example_executable(example_pool2d_fwd pool2d_fwd.cpp) +add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) +add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) + diff --git a/example/13_pool2d_fwd/README.md b/example/13_pool2d_fwd/README.md index d9c829fb98..9b017734e9 100644 --- a/example/13_pool2d_fwd/README.md +++ b/example/13_pool2d_fwd/README.md @@ -1,12 +1,12 @@ -# Instructions for ```example_pool2d_fwd``` Example +# Instructions for ```example_pool2d_fwd``` Examples -## Run ```example_pool2d_fwd``` +## Run ```example_pool2d_fwd_fp16``` ```bash #arg1: verification (0=no, 1=yes) #arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) -#arg3: run kernel # of times (>1) +#arg3: time kernel (0=no, 1=yes) #arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx -./bin/example_pool2d_fwd 1 1 10 +./bin/example_pool2d_fwd_fp16 1 1 1 ``` Result @@ -14,9 +14,28 @@ Result in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192} launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1} -Warm up +Warm up 1 time Start running 10 times... -Perf: 0.415453 ms, 1.37996 TFlops, 749.726 GB/s -error: 0 -max_diff: 0, 1, 1 +Perf: 0.397436 ms, 1.44252 TFlops, 783.713 GB/s +``` + +## Run ```example_pool2d_fwd_fp32``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg3: time kernel (0=no, 1=yes) +#arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, RightPx +./bin/example_pool2d_fwd_fp32 1 1 1 +``` + + +Result +``` +./bin/example_pool2d_fwd_fp32 1 1 1 +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +out_n_c_ho_wo: dim 4, lengths {128, 192, 36, 36}, strides {248832, 1, 6912, 192} +launch_and_time_kernel: grid_dim {124416, 1, 1}, block_dim {64, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 1.01823 ms, 0.563045 TFlops, 611.8 GB/s ``` diff --git a/example/13_pool2d_fwd/pool2d_fwd.cpp b/example/13_pool2d_fwd/pool2d_fwd.cpp deleted file mode 100644 index 9def6c24fe..0000000000 --- a/example/13_pool2d_fwd/pool2d_fwd.cpp +++ /dev/null @@ -1,315 +0,0 @@ -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_reduce_util.hpp" -#include "device_tensor.hpp" -#include "tensor_layout.hpp" -#include "reduction_operator.hpp" -#include "device_pool2d_fwd_nhwc_nhwc.hpp" - -using InDataType = ck::half_t; -using OutDataType = ck::half_t; -using AccDataType = float; - -using InLayout = ck::tensor_layout::convolution::NHWC; -using OutLayout = ck::tensor_layout::convolution::NHWC; - -#if 1 -static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; -#else -static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; -#endif - -static constexpr bool NeedIndices = false; -static constexpr bool PropagateNan = false; - -using DevicePoolFwdInstance = - ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< - InDataType, // InDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - ReduceOpId, - NeedIndices, - 64, // BlockSize - 64, // ReduceMThreadClusterSize - 1, // ReduceKThreadClusterSize - 4, // ReduceMThreadSliceSize - 1, // ReduceKThreadSliceSize - 4>; // InSrcOutDstVectorSize - -template -static void pool_host_verify(const Tensor& in, - Tensor& out, - Tensor& out_indices, - const std::array& window_spatial_lengths, - const std::array& window_strides, - const std::array& in_left_pads, - const std::array& /*in_right_pads*/) -{ - using namespace ck::host_reduce; - - const int divider = window_spatial_lengths[0] * window_spatial_lengths[1]; - - const auto PreUnaryOp = PreUnaryOpFn(divider); - const auto PosUnaryOp = PosUnaryOpFn(divider); - - if constexpr(!NeedIndices) - { - auto opReduce = ReduceOpFn(); - - auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { - auto accuVal = ReduceOpZeroVal(); - - for(int y = 0; y < window_spatial_lengths[0]; ++y) - { - int hi = ho * window_strides[0] + y - in_left_pads[0]; - for(int x = 0; x < window_spatial_lengths[1]; ++x) - { - int wi = wo * window_strides[1] + x - in_left_pads[1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - AccDataType currVal = static_cast(in(n, c, hi, wi)); - - PreUnaryOp(currVal); - - binop_with_nan_check(opReduce, accuVal, currVal); - } - } - } - - PosUnaryOp(accuVal); - - out(n, c, ho, wo) = accuVal; - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - } - else - { - auto opReduce = ReduceOpFn2(); - - auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { - auto accuVal = ReduceOpZeroVal(); - int accuIndex = 0; - - for(int y = 0; y < window_spatial_lengths[0]; ++y) - { - int hi = ho * window_strides[0] + y - in_left_pads[0]; - for(int x = 0; x < window_spatial_lengths[1]; ++x) - { - int wi = wo * window_strides[1] + x - in_left_pads[1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - AccDataType currVal = static_cast(in(n, c, hi, wi)); - int currIndex = y * window_spatial_lengths[1] + x; - - PreUnaryOp(currVal); - - binop_with_nan_check2( - opReduce, accuVal, currVal, accuIndex, currIndex); - } - } - } - - PosUnaryOp(accuVal); - - out(n, c, ho, wo) = accuVal; - out_indices(n, c, ho, wo) = accuIndex; - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); - }; -} - -int main(int argc, char* argv[]) -{ - using namespace ck::host_reduce; - - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; - - // Pool shape - ck::index_t N = 128; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t window_stride_h = 2; - ck::index_t window_stride_w = 2; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - } - else if(argc == 16) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - - N = std::stoi(argv[4]); - C = std::stoi(argv[5]); - Y = std::stoi(argv[6]); - X = std::stoi(argv[7]); - Hi = std::stoi(argv[8]); - Wi = std::stoi(argv[9]); - window_stride_h = std::stoi(argv[10]); - window_stride_w = std::stoi(argv[11]); - in_left_pad_h = std::stoi(argv[12]); - in_left_pad_w = std::stoi(argv[13]); - in_right_pad_h = std::stoi(argv[14]); - in_right_pad_w = std::stoi(argv[15]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(0); - } - - const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; - const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; - - const std::array window_spatial_lengths{{Y, X}}; - const std::array window_strides{{window_stride_h, window_stride_w}}; - const std::array input_left_pads{{in_left_pad_h, in_left_pad_w}}; - const std::array input_right_pads{{in_right_pad_h, in_right_pad_w}}; - - // tensor layout - auto f_host_tensor_descriptor = - [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { - if constexpr(ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, H * W, W, 1})); - } - else if constexpr(ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, 1, W * C_, C_})); - } - }; - - Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); - Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); - Tensor out_indices_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); - Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); - Tensor out_indices_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); - - std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; - std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); break; - case 2: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; - default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); - } - - DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * out_n_c_ho_wo_device.mDesc.GetElementSpace()); - DeviceMem out_indices_device_buf(sizeof(int) * - out_indices_n_c_ho_wo_device.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - - auto pool = DevicePoolFwdInstance{}; - auto invoker_ptr = pool.MakeInvokerPointer(); - auto argument_ptr = - pool.MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - static_cast(out_indices_device_buf.GetDeviceBuffer()), - N, - C, - std::array{{Hi, Wi}}, - std::array{{Y, X}}, - std::array{{Ho, Wo}}, - window_strides, - input_left_pads, - input_right_pads); - - if(!pool.IsSupportedArgument(argument_ptr.get())) - { - throw std::runtime_error("wrong! device_op with the specified compilation parameters does " - "not support this problem"); - } - - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); - - std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X; - - std::size_t num_btype = - sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(OutDataType) * (N * C * Ho * Wo); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - if(do_verification) - { - pool_host_verify(in_n_c_hi_wi, - out_n_c_ho_wo_host, - out_indices_n_c_ho_wo_host, - window_spatial_lengths, - window_strides, - input_left_pads, - input_right_pads); - - out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); - - ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData); - - if constexpr(NeedIndices) - { - out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); - - // ck::utils::check_err(out_indices_n_c_ho_wo_device.mData, - // out_indices_n_c_ho_wo_host.mData);; - }; - } -} diff --git a/example/13_pool2d_fwd/pool2d_fwd_common.hpp b/example/13_pool2d_fwd/pool2d_fwd_common.hpp new file mode 100644 index 0000000000..436bbcd485 --- /dev/null +++ b/example/13_pool2d_fwd/pool2d_fwd_common.hpp @@ -0,0 +1,280 @@ +#pragma once + +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "reduction_enums.hpp" +#include "reduction_operator_mapping.hpp" +#include "reduction_functions_accumulate.hpp" + +#include "device_pool2d_fwd_nhwc_nhwc.hpp" + +template +static void pool_host_verify(const Tensor& in, + Tensor& out, + Tensor& out_indices, + const std::array& window_spatial_lengths, + const std::array& window_strides, + const std::array& in_left_pads, + const std::array& /*in_right_pads*/) +{ + const int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; + + using ReduceOperation = typename ck::reduce_binary_operator::opType; + + auto elementwise_ops = + ck::reduce_unary_operator::GetElementwiseOperator(reduceLength); + + auto in_elementwise_op = std::get<0>(elementwise_ops); + auto acc_elementwise_op = std::get<1>(elementwise_ops); + + if constexpr(!OutputIndex) + { + using Accumulation = + ck::detail::AccumulateWithNanCheck; + + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + + for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) + { + ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; + for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) + { + ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; + if(hi >= 0 && hi < static_cast(in.mDesc.GetLengths()[2]) && + wi >= 0 && wi < static_cast(in.mDesc.GetLengths()[3])) + { + AccDataType currVal = static_cast(in(n, c, hi, wi)); + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal); + } + } + } + + acc_elementwise_op(accuVal, accuVal); + + out(n, c, ho, wo) = accuVal; + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; + auto f_nchw = [&](auto n, auto c, auto ho, auto wo) { + auto accuVal = ReduceOperation::template GetIdentityValue(); + IndexDataType accuIndex = 0; + + for(ck::index_t y = 0; y < window_spatial_lengths[0]; ++y) + { + ck::index_t hi = ho * window_strides[0] + y - in_left_pads[0]; + for(ck::index_t x = 0; x < window_spatial_lengths[1]; ++x) + { + ck::index_t wi = wo * window_strides[1] + x - in_left_pads[1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + AccDataType currVal = static_cast(in(n, c, hi, wi)); + IndexDataType currIndex = y * window_spatial_lengths[1] + x; + + in_elementwise_op(currVal, currVal); + + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); + } + } + } + + acc_elementwise_op(accuVal, accuVal); + + out(n, c, ho, wo) = accuVal; + out_indices(n, c, ho, wo) = accuIndex; + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + }; +} + +template +bool pool_test(bool do_verification, + int init_method, + bool time_kernel, + ck::index_t N, + ck::index_t C, + ck::index_t Y, + ck::index_t X, + ck::index_t Hi, + ck::index_t Wi, + ck::index_t window_stride_h, + ck::index_t window_stride_w, + ck::index_t in_left_pad_h, + ck::index_t in_left_pad_w, + ck::index_t in_right_pad_h, + ck::index_t in_right_pad_w) +{ + using DevicePoolFwdInstance = + ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< + InDataType, // InDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + ReduceOpId, + OutputIndex, + 64, // BlockSize + 64, // ReduceMThreadClusterSize + 1, // ReduceKThreadClusterSize + 4, // ReduceMThreadSliceSize + 1, // ReduceKThreadSliceSize + 4>; // InSrcOutDstVectorSize + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1; + + const std::array window_spatial_lengths{{Y, X}}; + const std::array window_strides{{window_stride_h, window_stride_w}}; + const std::array input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::array input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = + [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { + if constexpr(ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor out_n_c_ho_wo_host(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_ho_wo_host( + f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_n_c_ho_wo_device(f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + Tensor out_indices_n_c_ho_wo_device( + f_host_tensor_descriptor(N, C, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "out_n_c_ho_wo: " << out_n_c_ho_wo_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); break; + case 2: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; + default: in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_c_ho_wo_device.mDesc.GetElementSpace()); + DeviceMem out_indices_device_buf(sizeof(IndexDataType) * + out_indices_n_c_ho_wo_device.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + + auto pool = DevicePoolFwdInstance{}; + auto invoker_ptr = pool.MakeInvokerPointer(); + auto argument_ptr = pool.MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + static_cast(out_indices_device_buf.GetDeviceBuffer()), + N, + C, + std::array{{Hi, Wi}}, + std::array{{Y, X}}, + std::array{{Ho, Wo}}, + window_strides, + input_left_pads, + input_right_pads); + + if(!pool.IsSupportedArgument(argument_ptr.get())) + { + throw std::runtime_error("wrong! device_op with the specified compilation parameters does " + "not support this problem"); + } + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * N * C * Ho * Wo * Y * X; + + std::size_t num_btype = + sizeof(InDataType) * (N * C * Hi * Wi) + sizeof(OutDataType) * (N * C * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + pool_host_verify(in_n_c_hi_wi, + out_n_c_ho_wo_host, + out_indices_n_c_ho_wo_host, + window_spatial_lengths, + window_strides, + input_left_pads, + input_right_pads); + + out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData); + + if constexpr(OutputIndex) + { + out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); + + pass = pass && ck::utils::check_err(out_indices_n_c_ho_wo_device.mData, + out_indices_n_c_ho_wo_host.mData); + }; + } + + return (pass); +}; diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp new file mode 100644 index 0000000000..74507fdfb3 --- /dev/null +++ b/example/13_pool2d_fwd/pool2d_fwd_fp16.cpp @@ -0,0 +1,114 @@ +#include +#include + +#include "config.hpp" +#include "tensor_layout.hpp" +#include "reduction_enums.hpp" + +#include "pool2d_fwd_common.hpp" + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using OutLayout = ck::tensor_layout::convolution::NHWC; + +#if 1 +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + +static constexpr bool OutputIndex = false; +static constexpr bool PropagateNan = false; + +int main(int argc, char* argv[]) +{ + bool do_verification; + int init_method; + bool time_kernel; + + // Pool shape + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 1) + { + do_verification = true; + init_method = 1; + time_kernel = true; + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + } + else if(argc == 16) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + in_left_pad_h = std::stoi(argv[12]); + in_left_pad_w = std::stoi(argv[13]); + in_right_pad_h = std::stoi(argv[14]); + in_right_pad_w = std::stoi(argv[15]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + bool pass = pool_test(do_verification, + init_method, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp b/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp new file mode 100644 index 0000000000..7ca5b1aab7 --- /dev/null +++ b/example/13_pool2d_fwd/pool2d_fwd_fp32.cpp @@ -0,0 +1,114 @@ +#include +#include + +#include "config.hpp" +#include "tensor_layout.hpp" +#include "reduction_enums.hpp" + +#include "pool2d_fwd_common.hpp" + +using InDataType = float; +using OutDataType = float; +using AccDataType = float; + +using IndexDataType = int32_t; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using OutLayout = ck::tensor_layout::convolution::NHWC; + +#if 1 +static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX; +#else +static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG; +#endif + +static constexpr bool OutputIndex = false; +static constexpr bool PropagateNan = false; + +int main(int argc, char* argv[]) +{ + bool do_verification; + int init_method; + bool time_kernel; + + // Pool shape + ck::index_t N = 128; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t window_stride_h = 2; + ck::index_t window_stride_w = 2; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 1) + { + do_verification = true; + init_method = 1; + time_kernel = true; + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + } + else if(argc == 16) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = static_cast(std::stoi(argv[3])); + + N = std::stoi(argv[4]); + C = std::stoi(argv[5]); + Y = std::stoi(argv[6]); + X = std::stoi(argv[7]); + Hi = std::stoi(argv[8]); + Wi = std::stoi(argv[9]); + window_stride_h = std::stoi(argv[10]); + window_stride_w = std::stoi(argv[11]); + in_left_pad_h = std::stoi(argv[12]); + in_left_pad_w = std::stoi(argv[13]); + in_right_pad_h = std::stoi(argv[14]); + in_right_pad_w = std::stoi(argv[15]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + bool pass = pool_test(do_verification, + init_method, + time_kernel, + N, + C, + Y, + X, + Hi, + Wi, + window_stride_h, + window_stride_w, + in_left_pad_h, + in_left_pad_w, + in_right_pad_h, + in_right_pad_w); + + return (pass ? 0 : 1); +} diff --git a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp index 324dc35d3f..a42df2b7f0 100644 --- a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp +++ b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp @@ -100,14 +100,19 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; // GEMM shape ck::index_t M = 3840; @@ -125,13 +130,13 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } else if(argc == 10) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); M = std::stoi(argv[4]); N = std::stoi(argv[5]); @@ -145,7 +150,7 @@ int main(int argc, char* argv[]) { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); exit(0); } @@ -219,7 +224,7 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = @@ -244,7 +249,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; } return 0; diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 4e9bdbb2f5..503c87e138 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -56,29 +56,29 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; + ReferenceGemm; int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; if(argc == 4) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); exit(0); } - int group_count = 4; + int group_count = rand() % 16 + 1; // GEMM shape std::vector gemm_shapes; @@ -131,7 +131,7 @@ int main(int argc, char* argv[]) std::size_t flop = 0, num_btype = 0; - for(int i = 0; i < gemm_shapes.size(); i++) + for(std::size_t i = 0; i < gemm_shapes.size(); i++) { a_tensors.push_back(Tensor(f_host_tensor_descriptor( gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); @@ -168,7 +168,7 @@ int main(int argc, char* argv[]) } } - for(int i = 0; i < gemm_shapes.size(); i++) + for(std::size_t i = 0; i < gemm_shapes.size(); i++) { a_tensors_device.emplace_back( std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpace())); @@ -189,12 +189,17 @@ int main(int argc, char* argv[]) auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; - // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); + + // do GEMM auto argument = gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer()); + if(!gemm.IsSupportedArgument(argument)) { throw std::runtime_error( @@ -202,7 +207,7 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, nrepeat); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -211,9 +216,10 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << gemm.GetTypeString() << std::endl; + bool pass = true; if(do_verification) { - for(int i = 0; i < gemm_shapes.size(); i++) + for(std::size_t i = 0; i < gemm_shapes.size(); i++) { c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); auto ref_gemm = ReferenceGemmInstance{}; @@ -227,9 +233,9 @@ int main(int argc, char* argv[]) c_element_op); ref_invoker.Run(ref_argument); - ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData); + pass &= ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData); } } - return 0; + return pass ? 0 : 1; } diff --git a/example/16_gemm_reduce/CMakeLists.txt b/example/16_gemm_reduce/CMakeLists.txt index 08d37b34a6..90ff589794 100644 --- a/example/16_gemm_reduce/CMakeLists.txt +++ b/example/16_gemm_reduce/CMakeLists.txt @@ -1 +1,2 @@ -add_example_executable(example_gemm_reduce_xdl_fp16 gemm_reduce_xdl_fp16.cpp) +add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp) +add_example_executable(example_gemm_reduce_xdl_mean_squaremean_fp16 gemm_reduce_xdl_mean_squaremean_fp16.cpp) diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp deleted file mode 100644 index 90064ae584..0000000000 --- a/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp +++ /dev/null @@ -1,273 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "device_gemm_reduce_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reduction_operator.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" -#include "reduction_operator.hpp" - -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using ADataType = F16; -using BDataType = F16; -using CDataType = F16; -using DDataType = F32; - -using ALayout = ck::tensor_layout::gemm::RowMajor; -using BLayout = ck::tensor_layout::gemm::ColumnMajor; -using CLayout = ck::tensor_layout::gemm::RowMajor; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CElementOp = ck::tensor_operation::element_wise::PassThrough; -using D0ReduceOp = ck::reduce::Add; -using D1ReduceOp = ck::reduce::Add; -using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare; - -static constexpr auto GemmSpecialization = - ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| -//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, D1ElementOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -int main(int argc, char* argv[]) -{ - bool do_verification = 1; - int init_method = 1; - int nrepeat = 5; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; - - if(argc == 1) - { - // do nothing - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - } - else if(argc == 10) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); - } - }; - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d0_m_host_result( - HostTensorDescriptor(std::vector({static_cast(M)}))); - Tensor d1_m_host_result( - HostTensorDescriptor(std::vector({static_cast(M)}))); - - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor d0_m_device_result( - HostTensorDescriptor(std::vector({static_cast(M)}))); - Tensor d1_m_device_result( - HostTensorDescriptor(std::vector({static_cast(M)}))); - - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; - std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); - DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); - DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); - - a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - auto d1_element_op = D1ElementOp{}; - - // do GEMM - auto gemm = DeviceGemmReduceInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - static_cast(d0_device_buf.GetDeviceBuffer()), - static_cast(d1_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - StrideB, - StrideC, - a_element_op, - b_element_op, - c_element_op, - d1_element_op); - - if(!gemm.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - // warm up - invoker.Run(argument); - - // timing - float total_time = 0; - - for(int i = 0; i < nrepeat; ++i) - { - // init DO, D1 to 0 - d0_device_buf.SetZero(); - d1_device_buf.SetZero(); - - KernelTimer timer; - - timer.Start(); - - invoker.Run(argument); - - timer.End(); - - total_time += timer.GetElapsedTime(); - } - - float ave_time = total_time / nrepeat; - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - - if(do_verification) - { - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - d0_device_buf.FromDevice(d0_m_device_result.mData.data()); - d1_device_buf.FromDevice(d1_m_device_result.mData.data()); - - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); - - ref_invoker.Run(ref_argument); - - auto d0_reduce_op = D0ReduceOp{}; - auto d1_reduce_op = D1ReduceOp{}; - - for(int m = 0; m < M; ++m) - { - float d0_acc = d0_reduce_op.GetReductionZeroVal(); - float d1_acc = d1_reduce_op.GetReductionZeroVal(); - - for(int n = 0; n < N; ++n) - { - float d0_val = ck::type_convert(c_m_n_host_result(m, n)); - float d1_val; - - d1_element_op(d1_val, d0_val); - d0_reduce_op(d0_acc, d0_val); - d1_reduce_op(d1_acc, d1_val); - } - - d0_m_host_result(m) = ck::type_convert(d0_acc); - d1_m_host_result(m) = ck::type_convert(d1_acc); - } - - check_error(c_m_n_host_result, c_m_n_device_result); - check_error(d0_m_host_result, d0_m_device_result); - check_error(d1_m_host_result, d1_m_device_result); - } - - return 0; -} diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp new file mode 100644 index 0000000000..92113e3c41 --- /dev/null +++ b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp @@ -0,0 +1,268 @@ +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using DDataType = F64; +using DPtrsGlobal = ck::Tuple; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using DsReduceOp = ck::Tuple; +using DsElementOp = ck::Tuple; +using DGlobalMemOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(DDataType) * M; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + + std::cout << "gemm + reduceMax Perf: " << gemm_reduce_time << " ms, " << tflops << " TFlops, " + << gemm_gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "d_m: " << d_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto ds_element_op = DsElementOp{}; + auto p_ds_global = ck::make_tuple(static_cast(d_device_buf.GetDeviceBuffer())); + + // do GEMM + auto gemm = DeviceGemmReduceInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + p_ds_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + ds_element_op, + ds_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // [CAUSION]: launch_and_time_kernel will not initialize D. + // If we evaluate kernel multiple time but without initialize D. Verification will fail + d_device_buf.SetValue(ck::NumericLimits::Lowest()); + invoker.Run(argument, StreamConfig{nullptr, false}); + + bool pass = true; + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d_device_buf.FromDevice(d_m_device_result.mData.data()); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}]; + + for(int m = 0; m < M; ++m) + { + ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType curr_val = + ck::type_convert(c_m_n_host_result(m, n)); + d_reduce_op(d_acc, curr_val); + }; + + d_m_host_result(m) = d_acc; + } + + pass = ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results c") && + ck::utils::check_err(d_m_device_result.mData, + d_m_host_result.mData, + "Error: Incorrect results d", + 1e-3, + 1e-3); + } + + if(time_kernel) + { + float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); + + DumpGemmLayerNormPerf( + gemm_reduceMax_ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp new file mode 100644 index 0000000000..018645e066 --- /dev/null +++ b/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp @@ -0,0 +1,305 @@ +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" +#include "reduction_operator.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using DDataType = F32; +using DPtrsGlobal = ck::Tuple; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using D0ReduceOp = ck::reduce::Add; +using D1ReduceOp = ck::reduce::Add; +using DxsReduceOp = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using DxsInElementOps = ck::Tuple; +using DxsOutElementOps = ck::Tuple; + +using DGlobalMemOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(DDataType) * M + + sizeof(DDataType) * M; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + + std::cout << "gemm + reduce_mean + reduce_mean_square Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d0_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor d0_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; + std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer())); + + auto dxs_in_element_op = DxsInElementOps{}; + auto dxs_out_element_op = DxsOutElementOps{N, N}; + + // do GEMM + auto gemm = DeviceGemmReduceInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + dxs_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); + + // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result + // will not be correct. need to set time_kernel = false for correctness test + invoker.Run(argument, StreamConfig{nullptr, false}); + bool pass = true; + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d0_device_buf.FromDevice(d0_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_m_device_result.mData.data()); + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + auto d0_reduce_op = D0ReduceOp{}; + auto d1_reduce_op = D1ReduceOp{}; + + for(int m = 0; m < M; ++m) + { + auto d0_acc = d0_reduce_op.GetIdentityValue(); + auto d1_acc = d1_reduce_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + auto c_val = ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType d0_val; + ReduceAccDataType d1_val; + + dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); + dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); + d0_reduce_op(d0_acc, d0_val); + d1_reduce_op(d1_acc, d1_val); + } + + dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc); + dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc); + d0_m_host_result(m) = ck::type_convert(d0_acc); + d1_m_host_result(m) = ck::type_convert(d1_acc); + } + + pass = ck::utils::check_err(c_m_n_device_result.mData, + c_m_n_host_result.mData, + "Error: Incorrect results c") && + ck::utils::check_err(d0_m_device_result.mData, + d0_m_host_result.mData, + "Error: Incorrect results d0", + 1e-4, + 1e-5) && + ck::utils::check_err(d1_m_device_result.mData, + d1_m_host_result.mData, + "Error: Incorrect results d1", + 1e-3, + 1e-5); + } + + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); + + DumpGemmLayerNormPerf(ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/example/17_convnd_bwd_data_xdl/CMakeLists.txt b/example/17_convnd_bwd_data_xdl/CMakeLists.txt index 0ed906f8f7..963f311703 100644 --- a/example/17_convnd_bwd_data_xdl/CMakeLists.txt +++ b/example/17_convnd_bwd_data_xdl/CMakeLists.txt @@ -1,2 +1,2 @@ add_example_executable(example_convnd_bwd_data_xdl convnd_bwd_data_xdl.cpp) -target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_fwd_util) +target_link_libraries(example_convnd_bwd_data_xdl PRIVATE conv_util) diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp index 962627ce90..0383197358 100644 --- a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -6,7 +6,7 @@ #include #include "config.hpp" -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -87,7 +87,7 @@ void print_use_msg() { std::cout << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" - << "arg3: run kernel # of times (>1)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" << "arg4: N spatial dimensions (default 2)\n" << "Following arguments (depending on number of spatial dims):\n" << " N, K, C, \n" @@ -105,40 +105,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) ck::utils::conv::ConvParams params; int arg_idx = 5; - params.num_dim_spatial = num_dim_spatial; - params.N = std::stoi(argv[arg_idx++]); - params.K = std::stoi(argv[arg_idx++]); - params.C = std::stoi(argv[arg_idx++]); + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); - params.filter_spatial_lengths.resize(num_dim_spatial); + params.filter_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.input_spatial_lengths.resize(num_dim_spatial); + params.input_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_strides.resize(num_dim_spatial); + params.conv_filter_strides_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_dilations.resize(num_dim_spatial); + params.conv_filter_dilations_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); } - params.input_left_pads.resize(num_dim_spatial); + params.input_left_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); } - params.input_right_pads.resize(num_dim_spatial); + params.input_right_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); } return params; @@ -165,25 +165,25 @@ DeviceConvBwdDataBasePtr get_conv_instance(int num_dim_spatial) int main(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; int num_dim_spatial = 2; ck::utils::conv::ConvParams params; - params.C = 128; + params.C_ = 128; if(argc == 4) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } else if(argc > 4) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); num_dim_spatial = std::stoi(argv[4]); // check args number int conv_args = 3 + num_dim_spatial * 6; @@ -202,21 +202,21 @@ int main(int argc, char* argv[]) exit(1); } - std::vector input_dims{static_cast(params.N), - static_cast(params.C)}; + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths), - std::end(params.input_spatial_lengths)); + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); - std::vector filter_dims{static_cast(params.K), - static_cast(params.C)}; + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths), - std::end(params.filter_spatial_lengths)); + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N), - static_cast(params.K)}; + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; output_dims.insert(std::end(output_dims), std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); @@ -263,16 +263,16 @@ int main(int argc, char* argv[]) conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); @@ -284,16 +284,16 @@ int main(int argc, char* argv[]) "not support this Conv problem"); } - float ave_time = invoker->Run(argument.get(), nrepeat); + float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = ck::utils::conv::get_flops( - params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); std::size_t num_btype = ck::utils::conv::get_btype( - params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -310,10 +310,10 @@ int main(int argc, char* argv[]) auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, wei_k_c_y_x, out_n_k_ho_wo, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); @@ -322,29 +322,30 @@ int main(int argc, char* argv[]) in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); - check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result); + return ck::utils::check_err(in_n_c_hi_wi_device_result.mData, + in_n_c_hi_wi_host_result.mData) + ? 0 + : 1; }; switch(num_dim_spatial) { case 3: { auto ref_conv = ReferenceConvBwdDataInstance<3>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 2: { auto ref_conv = ReferenceConvBwdDataInstance<2>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } case 1: { auto ref_conv = ReferenceConvBwdDataInstance<1>(); - verify_f(ref_conv); - break; + return verify_f(ref_conv); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); } } } + return 0; } diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index eb18655d1b..de584ad7e8 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -4,6 +4,7 @@ #include #include #include +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -24,10 +25,12 @@ using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; -using ADataType = F16; -using BDataType = F16; -using CDataType = F16; -using DDataType = F32; +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using ReduceAccDataType = F32; +using DDataType = F32; +using DPtrsGlobal = ck::Tuple; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; @@ -36,20 +39,29 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CElementOp = ck::tensor_operation::element_wise::PassThrough; -using D0ReduceOp = ck::reduce::Add; -using D1ReduceOp = ck::reduce::Add; -using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare; +using D0ReduceOp = ck::reduce::Add; +using D1ReduceOp = ck::reduce::Add; +using DxsReduceOp = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using DxsInElementOps = ck::Tuple; +using DxsOutElementOps = ck::Tuple; + +using DGlobalMemOp = + ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmSpecialization = ck::tensor_operation::device::GemmSpecialization::Default; // clang-format off using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| -//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, D1ElementOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: @@ -57,18 +69,18 @@ using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: int main(int argc, char* argv[]) { - bool do_verification = 1; + bool do_verification = true; int init_method = 1; - int nrepeat = 5; + bool time_kernel = false; // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; + ck::index_t M = 2048; + ck::index_t N = 1920; + ck::index_t K = 2048; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = 2048; + ck::index_t StrideB = 2048; + ck::index_t StrideC = 1920; ck::index_t BatchCount = 4; @@ -80,13 +92,13 @@ int main(int argc, char* argv[]) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); } else if(argc == 11) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + time_kernel = std::stoi(argv[3]); M = std::stoi(argv[4]); N = std::stoi(argv[5]); @@ -96,13 +108,13 @@ int main(int argc, char* argv[]) StrideB = std::stoi(argv[8]); StrideC = std::stoi(argv[9]); - BatchCount = std::stoi(argv[9]); + BatchCount = std::stoi(argv[10]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n"); exit(0); } @@ -169,12 +181,11 @@ int main(int argc, char* argv[]) a_device_buf.ToDevice(a_g_m_k.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data()); - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - auto d0_reduce_op = D0ReduceOp{}; - auto d1_reduce_op = D1ReduceOp{}; - auto d1_element_op = D1ElementOp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer())); // do GEMM auto batched_gemm = DeviceBatchedGemmReduceInstance{}; @@ -183,8 +194,7 @@ int main(int argc, char* argv[]) batched_gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), static_cast(b_device_buf.GetDeviceBuffer()), static_cast(c_device_buf.GetDeviceBuffer()), - static_cast(d0_device_buf.GetDeviceBuffer()), - static_cast(d1_device_buf.GetDeviceBuffer()), + dxs_global, M, N, K, @@ -194,7 +204,8 @@ int main(int argc, char* argv[]) a_element_op, b_element_op, c_element_op, - d1_element_op, + DxsInElementOps{}, + DxsOutElementOps{}, BatchCount); if(!batched_gemm.IsSupportedArgument(argument)) @@ -204,30 +215,13 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - // warm up - invoker.Run(argument); + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); - // timing - float total_time = 0; - - for(int i = 0; i < nrepeat; ++i) - { - // init DO, D1 to 0 - d0_device_buf.SetZero(); - d1_device_buf.SetZero(); - - KernelTimer timer; - - timer.Start(); - - invoker.Run(argument); - - timer.End(); - - total_time += timer.GetElapsedTime(); - } - - float ave_time = total_time / nrepeat; + // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result + // will not be correct. need to set time_kernel = false for correctness test + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * BatchCount * M * N * K; std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + @@ -241,6 +235,7 @@ int main(int argc, char* argv[]) std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << batched_gemm.GetTypeString() << std::endl; + bool pass = true; if(do_verification) { c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); @@ -255,19 +250,25 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); + auto d0_reduce_op = D0ReduceOp{}; + auto d1_reduce_op = D1ReduceOp{}; + for(int batch = 0; batch < BatchCount; ++batch) { for(int m = 0; m < M; ++m) { - float d0_acc = d0_reduce_op.GetReductionZeroVal(); - float d1_acc = d1_reduce_op.GetReductionZeroVal(); + auto d0_acc = d0_reduce_op.GetIdentityValue(); + auto d1_acc = d1_reduce_op.GetIdentityValue(); for(int n = 0; n < N; ++n) { - float d0_val = ck::type_convert(c_g_m_n_host_result(m, n)); - float d1_val; + auto c_val = + ck::type_convert(c_g_m_n_host_result(batch, m, n)); + ReduceAccDataType d0_val; + ReduceAccDataType d1_val; - d1_element_op(d1_val, d0_val); + UnaryIdenticElementOp{}(d0_val, c_val); + UnarySquareElementOp{}(d1_val, c_val); d0_reduce_op(d0_acc, d0_val); d1_reduce_op(d1_acc, d1_val); } @@ -277,10 +278,20 @@ int main(int argc, char* argv[]) } } - check_error(c_g_m_n_host_result, c_g_m_n_device_result); - check_error(d0_g_m_host_result, d0_g_m_device_result); - check_error(d1_g_m_host_result, d1_g_m_device_result); + pass = ck::utils::check_err(c_g_m_n_host_result.mData, + c_g_m_n_device_result.mData, + "Error: Incorrect results c") && + ck::utils::check_err(d0_g_m_device_result.mData, + d0_g_m_host_result.mData, + "Error: Incorrect results! D0", + 1e-4, + 1e-5) && + ck::utils::check_err(d1_g_m_device_result.mData, + d1_g_m_host_result.mData, + "Error: Incorrect results! D1", + 1e-3, + 1e-5); } - return 0; + return pass ? 0 : 1; } diff --git a/example/19_binary_elementwise/CMakeLists.txt b/example/19_binary_elementwise/CMakeLists.txt new file mode 100644 index 0000000000..39646e0ab5 --- /dev/null +++ b/example/19_binary_elementwise/CMakeLists.txt @@ -0,0 +1,4 @@ +add_example_executable(example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp) +add_example_executable(example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp) +add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp) +add_example_executable(example_elementwise_add_4d elementwise_add_4d.cpp) \ No newline at end of file diff --git a/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp new file mode 100644 index 0000000000..587882ed9c --- /dev/null +++ b/example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp @@ -0,0 +1,164 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; + +template +void host_broadcast2D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + ComputeDataType Amn = ck::type_convert(A(m, n)); + ComputeDataType Cmn = 0; + if constexpr(broadcastDim == 0) + { + ComputeDataType Bn = ck::type_convert(B(n)); + functor(Cmn, Amn, Bn); + } + else + { + ComputeDataType Bm = ck::type_convert(B(m)); + functor(Cmn, Amn, Bm); + } + C(m, n) = ck::type_convert(Cmn); + } + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t Stride = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + }; + + Tensor a_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + Tensor b_n(f_host_tensor_descriptor1d(N, 1)); + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + a_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpace()); + DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); + + a_m_n_device_buf.ToDevice(a_m_n.mData.data()); + b_n_device_buf.ToDevice(b_n.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(), + b_n_device_buf.GetDeviceBuffer(), + c_m_n_device_buf.GetDeviceBuffer(), + {M, N}, + {Stride, 1}, + {0, 1}, // broadcast in first dimension + {Stride, 1}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); + Tensor host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride)); + + host_broadcast2D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add, + 0>(host_c_m_n, a_m_n, b_n, M, N, Add{}); + + pass &= ck::utils::check_err( + c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp new file mode 100644 index 0000000000..e03f3fa76e --- /dev/null +++ b/example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp @@ -0,0 +1,123 @@ +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; + +template +void host_broadcast3D_am_bmnk(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t m = 0; m < shape[0]; ++m) + for(std::size_t n = 0; n < shape[1]; ++n) + for(std::size_t k = 0; k < shape[2]; ++k) + { + ComputeDataType a_val = ck::type_convert(A(m)); + ComputeDataType b_val = ck::type_convert(B(m, n, k)); + ComputeDataType c_val = 0; + functor(c_val, a_val, b_val); + C(m, n, k) = ck::type_convert(c_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + std::vector mnk = {4, 16, 32}; + ck::index_t M = mnk[0]; + + Tensor a_m({M}); + Tensor b_m_n_k(mnk); + Tensor c_m_n_k(mnk); + + a_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_m_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); + DeviceMem b_m_n_k_device_buf(sizeof(ABDataType) * b_m_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_k_device_buf(sizeof(CDataType) * c_m_n_k.mDesc.GetElementSpace()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + a_m_device_buf.GetDeviceBuffer(), + b_m_n_k_device_buf.GetDeviceBuffer(), + c_m_n_k_device_buf.GetDeviceBuffer(), + std::vector{mnk.begin(), mnk.end()}, + {1, 0, 0}, // broadcast A on second and third dimension + std::vector{b_m_n_k.mDesc.GetStrides().begin(), + b_m_n_k.mDesc.GetStrides().end()}, + std::vector{c_m_n_k.mDesc.GetStrides().begin(), + c_m_n_k.mDesc.GetStrides().end()}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_n_k_device_buf.FromDevice(c_m_n_k.mData.data()); + Tensor host_c_m_n_k(mnk); + + host_broadcast3D_am_bmnk, + Tensor, + Tensor, + EltwiseComputeDataType, + Add>(host_c_m_n_k, a_m, b_m_n_k, mnk, Add{}); + + pass &= ck::utils::check_err( + c_m_n_k.mData, host_c_m_n_k.mData, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/elementwise_add_1d.cpp b/example/19_binary_elementwise/elementwise_add_1d.cpp new file mode 100644 index 0000000000..c96e9616d7 --- /dev/null +++ b/example/19_binary_elementwise/elementwise_add_1d.cpp @@ -0,0 +1,144 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; + +template +void host_elementwise1D( + HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(int m = 0; m < M; ++m) + { + ComputeDataType Am = ck::type_convert(A(m)); + ComputeDataType Bm = ck::type_convert(B(m)); + ComputeDataType Cm = 0; + functor(Cm, Am, Bm); + C(m) = ck::type_convert(Cm); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + ck::index_t M = 1024; + + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + Tensor a_m(f_host_tensor_descriptor1d(M, 1)); + Tensor b_m(f_host_tensor_descriptor1d(M, 1)); + Tensor c_m(f_host_tensor_descriptor1d(M, 1)); + + a_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_m.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace()); + DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace()); + DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace()); + + a_m_device_buf.ToDevice(a_m.mData.data()); + b_m_device_buf.ToDevice(b_m.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(), + b_m_device_buf.GetDeviceBuffer(), + c_m_device_buf.GetDeviceBuffer(), + {M}, + {1}, + {1}, + {1}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_m_device_buf.FromDevice(c_m.mData.data()); + Tensor host_c_m(f_host_tensor_descriptor1d(M, 1)); + + host_elementwise1D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add>(host_c_m, a_m, b_m, M, Add{}); + + pass &= ck::utils::check_err( + c_m.mData, host_c_m.mData, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/19_binary_elementwise/elementwise_add_4d.cpp b/example/19_binary_elementwise/elementwise_add_4d.cpp new file mode 100644 index 0000000000..13345ec11f --- /dev/null +++ b/example/19_binary_elementwise/elementwise_add_4d.cpp @@ -0,0 +1,146 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +#include "device_tensor.hpp" +#include "binary_element_wise_operation.hpp" +#include "device_binary_elementwise.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ABDataType = F16; +using CDataType = F16; +using EltwiseComputeDataType = F32; + +using Add = ck::tensor_operation::element_wise::Add; + +using DeviceElementwiseAddInstance = + ck::tensor_operation::device::DeviceBinaryElementwise; + +template +void host_elementwise4D(HostTensorC& C, + const HostTensorA& A, + const HostTensorB& B, + const std::vector& shape, + Functor functor) +{ + using ctype = ck::remove_reference_t; + + for(std::size_t n = 0; n < shape[0]; ++n) + for(std::size_t c = 0; c < shape[1]; ++c) + for(std::size_t h = 0; h < shape[2]; ++h) + for(std::size_t w = 0; w < shape[3]; ++w) + { + ComputeDataType a_val = ck::type_convert(A(n, c, h, w)); + ComputeDataType b_val = ck::type_convert(B(n, c, h, w)); + ComputeDataType c_val = 0; + functor(c_val, a_val, b_val); + C(n, c, h, w) = ck::type_convert(c_val); + } +} + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + std::vector nchw = {4, 16, 32, 32}; + + Tensor a(nchw); + Tensor b(nchw); + Tensor c(nchw); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ABDataType) * a.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(ABDataType) * b.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a.mData.data()); + b_device_buf.ToDevice(b.mData.data()); + + auto broadcastAdd = DeviceElementwiseAddInstance{}; + auto argument = broadcastAdd.MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + std::vector{nchw.begin(), nchw.end()}, + std::vector{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()}, + std::vector{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()}, + std::vector{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()}, + Add{}); + + if(!broadcastAdd.IsSupportedArgument(argument.get())) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "DeviceBinaryElementwise instance, exiting!"); + }; + + auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer(); + float ave_time = + broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::cout << "Perf: " << ave_time << " ms" << std::endl; + + bool pass = true; + if(do_verification) + { + c_device_buf.FromDevice(c.mData.data()); + Tensor host_c(nchw); + + host_elementwise4D, + Tensor, + Tensor, + EltwiseComputeDataType, + Add>(host_c, a, b, nchw, Add{}); + + pass &= + ck::utils::check_err(c.mData, host_c.mData, "Error: Incorrect results c", 1e-3, 1e-3); + } + + return pass ? 0 : 1; +} diff --git a/example/20_convnd_bwd_weight_xdl/CMakeLists.txt b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt new file mode 100644 index 0000000000..66fdef625a --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/CMakeLists.txt @@ -0,0 +1,4 @@ +add_example_executable(example_convnd_bwd_weight_xdl convnd_bwd_weight_xdl.cpp) +add_example_executable(example_convnd_bwd_weight_xdl_bf16_splitk convnd_bwd_weight_xdl_bf16_splitk.cpp) +target_link_libraries(example_convnd_bwd_weight_xdl PRIVATE conv_util) +target_link_libraries(example_convnd_bwd_weight_xdl_bf16_splitk PRIVATE conv_util) \ No newline at end of file diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp new file mode 100644 index 0000000000..f917c2c3ac --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp @@ -0,0 +1,385 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "conv_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "element_wise_operation.hpp" +#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_backward_weight.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +using DeviceConvBwdWeightBasePtr = + ck::tensor_operation::device::DeviceConvBwdWeightPtr; + +// clang-format off +template +using DeviceConvndBwdWeightInstance = ck::tensor_operation::device:: + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +template +using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: is show log (0=no, 1=yes)\n" + << "arg5: split-k \n" + << "arg6: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + int arg_idx = 7; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + int do_log = 0; + int split_k = 1; + + ck::utils::conv::ConvParams params; + params.C_ = 128; + + if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } + else if(argc > 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + num_dim_spatial = std::stoi(argv[6]); + // check args number + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 7; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(1); + } + + params = parse_conv_params(num_dim_spatial, argv); + } + else if(argc != 1) + { + print_use_msg(); + exit(1); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor in_n_c_hi_wi( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_host_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_device_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + // reset input to zero + wei_device_buf.SetZero(); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + // alloc work space + float ave_time = 0.f; + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = ck::utils::conv::get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&](const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, + wei_k_c_y_x_host_result.mData) + ? 0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvBwdWeightInstance<3>(); + return verify_f(ref_conv); + } + case 2: { + auto ref_conv = ReferenceConvBwdWeightInstance<2>(); + return verify_f(ref_conv); + } + case 1: { + auto ref_conv = ReferenceConvBwdWeightInstance<1>(); + return verify_f(ref_conv); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp new file mode 100644 index 0000000000..43f0cdb7ec --- /dev/null +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl_bf16_splitk.cpp @@ -0,0 +1,427 @@ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "conv_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "element_wise_operation.hpp" +#include "device_unary_elementwise.hpp" +#include "device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_backward_weight.hpp" + +using InDataType = ck::bhalf_t; +using WeiDataType = ck::bhalf_t; +using OutDataType = ck::bhalf_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using UnaryTypeConvert = ck::tensor_operation::element_wise::UnaryTypeConvert; + +using DeviceUnaryElementwiseTypeConvertInstance = ck::tensor_operation::device:: + DeviceUnaryElementwise; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +using DeviceConvBwdWeightBasePtr = + ck::tensor_operation::device::DeviceConvBwdWeightPtr; + +// clang-format off +template +using DeviceConvndBwdWeightInstance_bf16_splitk = ck::tensor_operation::device:: + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + InDataType, // InDataType + AccDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 2, // NXdlPerWave + S<1, 4, 16, 4>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<0, 3, 1, 2>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 2, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 4, 16, 4>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<0, 3, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1, 3>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 4>, // CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +template +using ReferenceConvBwdWeightInstance = + ck::tensor_operation::host::ReferenceConvBwdWeight; + +template +void host_elementwise(HostTensorB& B, + const HostTensorA& A, + const std::vector& shape, + Functor functor) +{ + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}); + std::cout << __LINE__ << ":" << tensor_size << ", " << A.mData[0] << std::endl; + for(std::size_t n = 0; n < tensor_size; ++n) + { + B.mData[n] = functor(A.mData[n]); + } +} + +void print_use_msg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=random value, 2= init to 1 )\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4: is show log (0=no, 1=yes)\n" + << "arg5: split-k : in this example split-k must be larger than 1\n" + << "arg6: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + ck::utils::conv::ConvParams params; + int arg_idx = 7; + + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads_.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +DeviceConvBwdWeightBasePtr get_conv_instance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + int num_dim_spatial = 2; + int do_log = 0; + int split_k = 2; + + ck::utils::conv::ConvParams params; + params.C_ = 128; + + if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + } + else if(argc > 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + do_log = std::stoi(argv[4]); + split_k = std::stoi(argv[5]); + num_dim_spatial = std::stoi(argv[6]); + // check args number + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 7; + if(cmdline_nargs != argc) + { + print_use_msg(); + exit(1); + } + + params = parse_conv_params(num_dim_spatial, argv); + } + else if(argc != 1) + { + print_use_msg(); + exit(1); + } + + if(split_k <= 1) + { + print_use_msg(); + exit(1); + } + + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); + + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor in_n_c_hi_wi( + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_host_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor wei_k_c_y_x_device_result( + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_device_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x_host_result.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * + wei_k_c_y_x_device_result.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + // reset input to zero + wei_device_buf.SetZero(); + + // do GEMM + auto conv = get_conv_instance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); + + // alloc work space + size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get()); + if(bwd_weight_workspace_size <= 0) + { + print_use_msg(); + exit(1); + } + + float conv_ave_time = 0.f; + + DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); + wei_work_space_device_buf.SetZero(); + conv->SetWorkSpacePointer(argument.get(), wei_work_space_device_buf.GetDeviceBuffer()); + + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + + conv_ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = ck::utils::conv::get_flops( + params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::get_btype( + params.N_, + params.C_, + params.K_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / conv_ave_time; + + float gb_per_sec = num_btype / 1.E6 / conv_ave_time; + + std::cout << "Perf: conv: " << conv_ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + if(do_verification) + { + auto verify_f = [&](const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x_host_result, + out_n_k_ho_wo, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + + wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out_n_k_ho_wo.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in_n_c_hi_wi.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "wei_device(after): ", wei_k_c_y_x_device_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") + << std::endl; + } + + return ck::utils::check_err(wei_k_c_y_x_device_result.mData, + wei_k_c_y_x_host_result.mData) + ? 0 + : 1; + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvBwdWeightInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvBwdWeightInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvBwdWeightInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } + return 0; +} diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt new file mode 100644 index 0000000000..99b50fefed --- /dev/null +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_fp16 gemm_bias_relu_add_layernorm_xdl_fp16.cpp) +add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp) diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp new file mode 100644 index 0000000000..59cbb41005 --- /dev/null +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_fp16.cpp @@ -0,0 +1,425 @@ +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_5ary_elementwise.hpp" +#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using C0DataType = F32; +using C1DataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using DDataType = F32; +using DPtrsGlobal = ck::Tuple; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; +using NormalizeComputeDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = ck::tensor_operation::element_wise::Relu; +using C1ElementOp = PassThrough; +using ReduceSumOp = ck::reduce::Add; +using DxsReduceOp = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using DxsInElementOps = ck::Tuple; +using DxsOutElementOps = ck::Tuple; + +using DxsGlobalMemOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, C1ElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; + +// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y +using DeviceNormalizeInstance = + ck::tensor_operation::device::Device5AryElementwise; // scalarPerVector: LayerNorm_out + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + +template +void host_gemm_layernorm(Tensor& out_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& bias_n, + const Tensor& c1_m_n, + const Tensor& gamma_n, + const Tensor& beta_n, + A_functor a_element_op, + B_functor b_element_op, + C_functor c_element_op, + C1_functor c1_element_op, + int M, + int N) +{ + + int StrideC = N; + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor meanSquare_m(f_host_tensor_descriptor1d(M, 1)); + auto averageOpInst = UnaryDivElementOp{N}; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + // c = activation(c + bias) + c1_functor(c1) + for(int m = 0; m < M; ++m) + for(int n = 0; n < N; ++n) + { + AccDataType acc = + static_cast(c_m_n(m, n)) + static_cast(bias_n(n)); + + AccDataType c1 = static_cast(c1_m_n(m, n)); + + c_element_op(acc, acc); + c1_element_op(c1, c1); + acc += c1; + c_m_n(m, n) = static_cast(acc); + } + + // reduce_mean and reduce_square_mean + auto reduceSumOpInst = ReduceSumOp{}; + for(int m = 0; m < M; ++m) + { + auto mean_acc = reduceSumOpInst.GetIdentityValue(); + auto square_mean_acc = reduceSumOpInst.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + AccDataType c_val = ck::type_convert(c_m_n(m, n)); + AccDataType square_c_val = 0; + UnarySquareElementOp{}(square_c_val, c_val); + + reduceSumOpInst(mean_acc, c_val); + reduceSumOpInst(square_mean_acc, square_c_val); + } + + averageOpInst(mean_acc, mean_acc); + averageOpInst(square_mean_acc, square_mean_acc); + mean_m(m) = ck::type_convert(mean_acc); + meanSquare_m(m) = ck::type_convert(square_mean_acc); + } + + // LayerNorm + auto layerNormInst = NormalizeFunctor{}; + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + AccDataType out_acc = 0; + layerNormInst(out_acc, + static_cast(c_m_n(m, n)), + static_cast(mean_m(m)), + static_cast(meanSquare_m(m)), + static_cast(gamma_n(n)), + static_cast(beta_n(n))); + out_m_n(m, n) = static_cast(out_acc); + } + } +} + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N + + sizeof(C1DataType) * M * N + sizeof(DDataType) * M + + sizeof(DDataType) * M; + + std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(DDataType) * M + + sizeof(DDataType) * M + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + float normalize_gb_per_sec = normalize_num_byte / 1.E6 / normalize_time; + + std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; + + std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec + << " GB/s, " << std::endl; +} + +int main() +{ + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideC = 1024; + ck::index_t StrideC1 = 1024; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduceMean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); + Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); + Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); + Tensor layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + bias_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + c1_m_n.GenerateTensorValue(GeneratorTensor_3{-5, 5}); + gamma_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + beta_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace()); + DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace()); + DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace()); + DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) * + reduceMeanSquare_m.mDesc.GetElementSpace()); + DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); + DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); + DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * + layerNorm_m_n.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + c1_device_buf.ToDevice(c1_m_n.mData.data()); + gamma_device_buf.ToDevice(gamma_n.mData.data()); + beta_device_buf.ToDevice(beta_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto c1_element_op = C1ElementOp{}; + auto dxs_global = + ck::make_tuple(static_cast(reduceMean_device_buf.GetDeviceBuffer()), + static_cast(reduceMeanSquare_device_buf.GetDeviceBuffer())); + + auto dxs_in_element_op = DxsInElementOps{}; + auto dxs_out_element_op = DxsOutElementOps{N, N}; + + // Prepare GEMM, reduce_mean, reduce_mean_square + auto gemmReduce = DeviceGemmBiasAddReduceInstance{}; + auto gemmReduce_invoker = gemmReduce.MakeInvoker(); + auto gemmReduce_argument = + gemmReduce.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + static_cast(c1_device_buf.GetDeviceBuffer()), + dxs_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideC1, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + dxs_in_element_op, + dxs_out_element_op); + + if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + reduceMean_device_buf.SetZero(); + reduceMeanSquare_device_buf.SetZero(); + + // Prepare LayerNorm + auto normalize = DeviceNormalizeInstance{}; + auto normalize_invoker = normalize.MakeInvoker(); + auto normalize_argument = normalize.MakeArgument( + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(reduceMean_device_buf.GetDeviceBuffer()), + static_cast(reduceMeanSquare_device_buf.GetDeviceBuffer()), + static_cast(gamma_device_buf.GetDeviceBuffer()), + static_cast(beta_device_buf.GetDeviceBuffer()), + static_cast(layerNorm_device_buf.GetDeviceBuffer()), + {M, N}, + {StrideC, 1}, + {1, 0}, + {1, 0}, + {0, 1}, + {0, 1}, + {StrideC, 1}, + NormalizeFunctor{}); + + if(!normalize.IsSupportedArgument(normalize_argument)) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "Device5AryElementwise instance, exiting!"); + } + + // run kernel + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false}); + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false}); + + bool pass = true; + { + // verification + Tensor host_layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + host_gemm_layernorm(host_layerNorm_m_n, + a_m_k, + b_k_n, + bias_n, + c1_m_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + M, + N); + + layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); + pass &= ck::utils::check_err(layerNorm_m_n.mData, + host_layerNorm_m_n.mData, + "Error: Incorrect results layerNorm_m_n", + 1e-2, + 1e-2); + } + + { + // evaluate kernel perf + bool time_kernel = true; + + float gemm_reduce_mean_reduce_square_mean_ave_time = + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); + float normalize_ave_time = + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + DumpGemmLayerNormPerf( + gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp new file mode 100644 index 0000000000..05c35477aa --- /dev/null +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp @@ -0,0 +1,379 @@ +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_5ary_elementwise.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using DDataType = F32; +using DPtrsGlobal = ck::Tuple; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; +using NormalizeComputeDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceSumOp = ck::reduce::Add; +using DxsReduceOp = ck::Tuple; + +using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; +using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; +using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; +using DxsInElementOps = ck::Tuple; +using DxsOutElementOps = ck::Tuple; + +using DxsGlobalMemOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; + +// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y +using DeviceNormalizeInstance = + ck::tensor_operation::device::Device5AryElementwise; // scalarPerVector: LayerNorm_out + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + +template +void host_gemm_layernorm(Tensor& out_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& gamma_n, + const Tensor& beta_n, + A_functor a_element_op, + B_functor b_element_op, + C_functor c_element_op, + int M, + int N) +{ + using out_type = ck::remove_reference_t; + + int StrideC = N; + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor meanSquare_m(f_host_tensor_descriptor1d(M, 1)); + auto averageOpInst = UnaryDivElementOp{N}; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + // reduce_mean and reduce_square_mean + auto reduceSumOpInst = ReduceSumOp{}; + for(int m = 0; m < M; ++m) + { + auto mean_acc = reduceSumOpInst.GetIdentityValue(); + auto square_mean_acc = reduceSumOpInst.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + auto c_val = ck::type_convert(c_m_n(m, n)); + auto square_c_val = reduceSumOpInst.GetIdentityValue(); + + UnarySquareElementOp{}(square_c_val, c_val); + + reduceSumOpInst(mean_acc, c_val); + reduceSumOpInst(square_mean_acc, square_c_val); + } + + averageOpInst(mean_acc, mean_acc); + averageOpInst(square_mean_acc, square_mean_acc); + mean_m(m) = ck::type_convert(mean_acc); + meanSquare_m(m) = ck::type_convert(square_mean_acc); + } + + // LayerNorm + auto layerNormInst = NormalizeFunctor{}; + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + float out_f32 = 0; + layerNormInst(out_f32, + static_cast(c_m_n(m, n)), + static_cast(mean_m(m)), + static_cast(meanSquare_m(m)), + static_cast(gamma_n(n)), + static_cast(beta_n(n))); + out_m_n(m, n) = static_cast(out_f32); + } + } +} + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(DDataType) * M + + sizeof(DDataType) * M; + + std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(DDataType) * M + + sizeof(DDataType) * M + sizeof(GammaDataType) * N + + sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + float normalize_gb_per_sec = normalize_num_btye / 1.E6 / normalize_time; + + std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; + + std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec + << " GB/s, " << std::endl; +} + +int main() +{ + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideC = 1024; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor reduceMean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1)); + Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); + Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); + Tensor layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + gamma_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + beta_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace()); + DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace()); + DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) * + reduceMeanSquare_m.mDesc.GetElementSpace()); + DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace()); + DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace()); + DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) * + layerNorm_m_n.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + gamma_device_buf.ToDevice(gamma_n.mData.data()); + beta_device_buf.ToDevice(beta_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto dxs_global = + ck::make_tuple(static_cast(reduceMean_device_buf.GetDeviceBuffer()), + static_cast(reduceMeanSquare_device_buf.GetDeviceBuffer())); + + auto dxs_in_element_op = DxsInElementOps{}; + auto dxs_out_element_op = DxsOutElementOps{N, N}; + + // Prepare GEMM, reduce_mean, reduce_mean_square + auto gemmReduce = DeviceGemmReduceInstance{}; + auto gemmReduce_invoker = gemmReduce.MakeInvoker(); + auto gemmReduce_argument = + gemmReduce.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + dxs_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op); + + if(!gemmReduce.IsSupportedArgument(gemmReduce_argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + reduceMean_device_buf.SetZero(); + reduceMeanSquare_device_buf.SetZero(); + + // Prepare LayerNorm + auto normalize = DeviceNormalizeInstance{}; + auto normalize_invoker = normalize.MakeInvoker(); + auto normalize_argument = normalize.MakeArgument( + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(reduceMean_device_buf.GetDeviceBuffer()), + static_cast(reduceMeanSquare_device_buf.GetDeviceBuffer()), + static_cast(gamma_device_buf.GetDeviceBuffer()), + static_cast(beta_device_buf.GetDeviceBuffer()), + static_cast(layerNorm_device_buf.GetDeviceBuffer()), + {M, N}, + {StrideC, 1}, + {1, 0}, + {1, 0}, + {0, 1}, + {0, 1}, + {StrideC, 1}, + NormalizeFunctor{}); + + if(!normalize.IsSupportedArgument(normalize_argument)) + { + throw std::runtime_error("The runtime parameters seems not supported by the " + "Device5AryElementwise instance, exiting!"); + } + + // run kernel + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false}); + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false}); + + bool pass = true; + { + // verification + Tensor host_layerNorm_m_n( + f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + + host_gemm_layernorm(host_layerNorm_m_n, + a_m_k, + b_k_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + c_element_op, + M, + N); + + layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data()); + pass &= ck::utils::check_err(layerNorm_m_n.mData, + host_layerNorm_m_n.mData, + "Error: Incorrect results d1", + 1e-3, + 1e-3); + } + + { + // evaluate kernel perf + bool time_kernel = true; + + float gemm_reduce_mean_reduce_square_mean_ave_time = + gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); + float normalize_ave_time = + normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel}); + + if(time_kernel) + DumpGemmLayerNormPerf( + gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K); + } + + return pass ? 0 : 1; +} diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt new file mode 100644 index 0000000000..048df3bba4 --- /dev/null +++ b/example/22_cgemm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) diff --git a/example/22_cgemm/cgemm_xdl_fp16.cpp b/example/22_cgemm/cgemm_xdl_fp16.cpp new file mode 100644 index 0000000000..9790164e72 --- /dev/null +++ b/example/22_cgemm/cgemm_xdl_fp16.cpp @@ -0,0 +1,302 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_cgemm_4gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_cgemm.hpp" +#include "gemm_specialization.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using AccDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceCGemmInstance = ck::tensor_operation::device::DeviceCGemm_4Gemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder + 2, // index_t ABlockTransferSrcVectorDim + 8, // index_t ABlockTransferSrcScalarPerVector + 8, // index_t ABlockTransferDstScalarPerVector_AK1 + 1, // index_t ABlockLdsExtraM + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder + 2, // index_t BBlockTransferSrcVectorDim + 8, // index_t BBlockTransferSrcScalarPerVector + 8, // index_t BBlockTransferDstScalarPerVector_BK1 + 1, // index_t BBlockLdsExtraN + 1, // index_t CShuffleMXdlPerWavePerShuffle + 1, // index_t CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock +// clang-format on + +using ReferenceCGemmInstance = ck::tensor_operation::host:: + ReferenceCGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // CGEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k_real(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a_m_k_imag(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n_real(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_imag(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_real_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_imag_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k_real: " << a_m_k_real.mDesc << std::endl; + std::cout << "a_m_k_imag: " << a_m_k_imag.mDesc << std::endl; + std::cout << "b_k_n_real: " << b_k_n_real.mDesc << std::endl; + std::cout << "b_k_n_imag: " << b_k_n_imag.mDesc << std::endl; + std::cout << "c_m_n_real: " << c_m_n_real_device_result.mDesc << std::endl; + std::cout << "c_m_n_imag: " << c_m_n_imag_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k_real.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a_m_k_imag.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n_real.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n_imag.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_m_k_real.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k_imag.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_real.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n_imag.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + auto cgemm = DeviceCGemmInstance{}; + + DeviceMem a_m_k_real_device_buf(sizeof(ADataType) * a_m_k_real.mDesc.GetElementSpace()); + DeviceMem a_m_k_imag_device_buf(sizeof(ADataType) * a_m_k_imag.mDesc.GetElementSpace()); + DeviceMem b_k_n_real_device_buf(sizeof(BDataType) * b_k_n_real.mDesc.GetElementSpace()); + DeviceMem b_k_n_imag_device_buf(sizeof(BDataType) * b_k_n_imag.mDesc.GetElementSpace()); + DeviceMem c_m_n_real_device_buf(sizeof(CDataType) * + c_m_n_real_device_result.mDesc.GetElementSpace()); + DeviceMem c_m_n_imag_device_buf(sizeof(CDataType) * + c_m_n_imag_device_result.mDesc.GetElementSpace()); + DeviceMem workspace_device_buf(cgemm.GetWorkspaceSize(M, N, K, StrideA, StrideB, StrideC)); + + a_m_k_real_device_buf.ToDevice(a_m_k_real.mData.data()); + a_m_k_imag_device_buf.ToDevice(a_m_k_imag.mData.data()); + b_k_n_real_device_buf.ToDevice(b_k_n_real.mData.data()); + b_k_n_imag_device_buf.ToDevice(b_k_n_imag.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto c_element_op = PassThrough{}; + + // do GEMM + auto invoker = cgemm.MakeInvoker(); + auto argument = + cgemm.MakeArgument(static_cast(a_m_k_real_device_buf.GetDeviceBuffer()), + static_cast(a_m_k_imag_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_real_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_imag_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_real_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_imag_device_buf.GetDeviceBuffer()), + static_cast(workspace_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!cgemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_cgemm with the specified compilation parameters does " + "not support this CGEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(8) * M * N * K; + std::size_t num_btype = + std::size_t(2) * + (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << cgemm.GetTypeString() << std::endl; + + c_m_n_real_device_buf.FromDevice(c_m_n_real_device_result.mData.data()); + c_m_n_imag_device_buf.FromDevice(c_m_n_imag_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n_real_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_imag_host_result( + f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + auto ref_cgemm = ReferenceCGemmInstance{}; + auto ref_invoker = ref_cgemm.MakeInvoker(); + + auto ref_argument = ref_cgemm.MakeArgument(a_m_k_real, + a_m_k_imag, + b_k_n_real, + b_k_n_imag, + c_m_n_real_host_result, + c_m_n_imag_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + ck::utils::check_err(c_m_n_real_device_result.mData, + c_m_n_real_host_result.mData, + "Verification error: incorrect results in real part!", + 1e-2f, + 1e-1f); + ck::utils::check_err(c_m_n_imag_device_result.mData, + c_m_n_imag_host_result.mData, + "Verification error: incorrect results in imaginary part!", + 1e-2f, + 1e-1f); + } + + return 0; +} diff --git a/example/23_softmax/CMakeLists.txt b/example/23_softmax/CMakeLists.txt new file mode 100644 index 0000000000..dafe65521a --- /dev/null +++ b/example/23_softmax/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_softmax_blockwise softmax_blockwise.cpp) \ No newline at end of file diff --git a/example/23_softmax/README.md b/example/23_softmax/README.md new file mode 100644 index 0000000000..37c43e9b55 --- /dev/null +++ b/example/23_softmax/README.md @@ -0,0 +1,18 @@ +# Instructions for ```example_softmax_blockwise``` + +## Run ```example_softmax_blockwise``` +```bash +# -D : input 3-d tensor lengths +# -v : verification (0=no, 1=yes) +#arg1: initialization (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value) +#arg2: time kernel (0=no, 1=yes) +example_softmax_blockwise -D 4,128,2048 -v 1 1 1 +``` + +Result +``` +launch_and_time_kernel: grid_dim {64, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 0.0242877 ms, 259.039 GB/s, DeviceReduceSoftmax<256,M_C8_S1,K_C32_S8,InSrcVectorDim_1_InSrcVectorSize_8_OutDstVectorSize_8> +``` diff --git a/example/23_softmax/softmax_blockwise.cpp b/example/23_softmax/softmax_blockwise.cpp new file mode 100644 index 0000000000..39432ac1fe --- /dev/null +++ b/example/23_softmax/softmax_blockwise.cpp @@ -0,0 +1,255 @@ +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_base.hpp" +#include "device_softmax.hpp" +#include "host_common_util.hpp" +#include "reference_softmax.hpp" + +#include "reduction_enums.hpp" +#include "reduction_operator_mapping.hpp" + +using namespace ck; +using namespace ck::tensor_operation::device; + +using InDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +constexpr int Rank = 3; +constexpr int NumReduceDim = 1; + +using DeviceInstance = DeviceSoftmax; // OutScalarPerVector + +static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, + {"verify", required_argument, nullptr, 'v'}, + {"help", no_argument, nullptr, '?'}, + {nullptr, 0, nullptr, 0}}; + +class SimpleAppArgs +{ + private: + int option_index = 0; + + public: + std::vector inLengths = {8, 128, 2048}; + std::vector scales = {2.0f, 2.0f}; + + bool do_verification = true; + int init_method = 2; + bool time_kernel = true; + + public: + void show_usage(const char* cmd) + { + std::cout << "Usage of " << cmd << std::endl; + std::cout << "--inLengths or -D, comma separated list of input tensor dimension lengths" + << std::endl; + std::cout << "--verify or -v, 1/0 to indicate whether to verify the reduction result by " + "comparing with the host-based reduction" + << std::endl; + std::cout << "Arg1 -- init method (0=no init, 1=single integer value, 2=scope integer " + "value, 3=decimal value)" + << std::endl; + std::cout << "Arg2 -- time kernel (0=no, 1=yes)" << std::endl; + }; + + int processArgs(int argc, char* argv[]) + { + using ck::host_common::getTypeValuesFromString; + + int ch; + + while(1) + { + ch = getopt_long(argc, argv, "D:v:l:", long_options, &option_index); + if(ch == -1) + break; + switch(ch) + { + case 'D': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + inLengths = getTypeValuesFromString(optarg); + break; + case 'v': + if(!optarg) + throw std::runtime_error("Invalid option format!"); + + do_verification = static_cast(std::atoi(optarg)); + break; + case '?': + if(std::string(long_options[option_index].name) == "help") + { + show_usage(argv[0]); + return (-1); + }; + break; + default: show_usage(argv[0]); return (-1); + }; + }; + + if(optind + 2 > argc) + throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); + + init_method = std::atoi(argv[optind++]); + time_kernel = static_cast(std::atoi(argv[optind])); + + if(scales.empty()) + { + scales.push_back(1.0f); + scales.push_back(0.0f); + }; + + return (0); + }; +}; + +int main(int argc, char* argv[]) +{ + // Example: batched gemm C[G, M, N] applies max/sum reduction along N internally + const std::vector invariantDims{0, 1}; + const std::vector reduceDims{2}; + + SimpleAppArgs args; + + if(argc > 1) + { + if(args.processArgs(argc, argv) < 0) + return (-1); + }; + + Tensor in(args.inLengths); + Tensor out_ref(args.inLengths); + Tensor out(args.inLengths); + + auto inStrides = in.mDesc.GetStrides(); + auto outStrides = out.mDesc.GetStrides(); + + AccDataType alpha = args.scales[0]; + AccDataType beta = args.scales[1]; + + std::size_t num_thread = 1; + + if(args.do_verification) + { + switch(args.init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + if(beta != 0.0f) + out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); + } + + if(beta != 0.0f) + for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) + out.mData[i] = out_ref.mData[i]; + }; + // std::cout << "beta = " << beta << std::endl; + // LogRangeAsType(std::cout << "tensor in: " , in.mData, ",") << std::endl; + // LogRangeAsType(std::cout << "tensor prior out: " , out.mData, ",") << std::endl; + + // these buffers are usually provided by the user application + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + + in_dev.ToDevice(in.mData.data()); + + if(beta != 0.0f) + out_dev.ToDevice(out.mData.data()); + + if(args.do_verification) + { + using ReferenceInstance = + tensor_operation::host::ReferenceSoftmax; + ReferenceInstance ref; + auto ref_arg = ref.MakeArgument(in, out_ref, alpha, beta, Rank, reduceDims); + auto invoker = ref.MakeInvoker(); + invoker.Run(ref_arg); + // LogRangeAsType(std::cout << "tensor out_ref: ", out_ref.mData, ",") << std::endl; + }; + + std::vector i_inLengths; + std::vector i_inStrides; + + i_inLengths.assign(args.inLengths.begin(), args.inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + + auto device_instance = DeviceInstance{}; + + auto argument_ptr = device_instance.MakeArgumentPointer(i_inLengths, + i_inStrides, + reduceDims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer()); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + std::cout + << "The runtime parameters seems not supported by the DeviceReduce instance, exiting!" + << std::endl; + return 1; + }; + + std::string instance_name = device_instance.GetTypeString(); + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + + bool pass = true; + if(args.do_verification) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + out_dev.FromDevice(out.mData.data()); + // LogRangeAsType(std::cout << "tensor out: " , out.mData, ",") << std::endl; + pass = pass && ck::utils::check_err(out.mData, out_ref.mData); + }; + + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, args.time_kernel}); + + std::size_t num_bytes = + in.mDesc.GetElementSize() * sizeof(InDataType) + + (beta == 0.0f ? 1 : 2) * out.mDesc.GetElementSize() * sizeof(OutDataType); + + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << instance_name + << std::endl; + + return (pass ? 0 : 1); +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 5f04125305..2b80fc44a2 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform @@ -19,17 +20,26 @@ include_directories(BEFORE add_custom_target(examples) -function(add_example_executable EXAMPLE_NAME) +function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") - add_executable(${EXAMPLE_NAME} ${ARGN}) + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor) + add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + add_dependencies(examples ${EXAMPLE_NAME}) + add_dependencies(check ${EXAMPLE_NAME}) +endfunction(add_example_executable EXAMPLE_NAME) + +function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) + message("adding example ${EXAMPLE_NAME}") + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE host_tensor) add_dependencies(examples ${EXAMPLE_NAME}) -endfunction(add_example_executable EXAMPLE_NAME) +endfunction(add_example_executable_no_testing EXAMPLE_NAME) add_subdirectory(01_gemm) add_subdirectory(02_gemm_alpha_beta) add_subdirectory(03_gemm_bias_relu) -add_subdirectory(04_gemm_bias_relu_add) +add_subdirectory(04_gemm_add_add_fastgelu) add_subdirectory(06_conv2d_fwd_bias_relu) add_subdirectory(07_conv2d_fwd_bias_relu_add) add_subdirectory(09_convnd_fwd) @@ -38,7 +48,12 @@ add_subdirectory(11_conv2d_bwd_weight) add_subdirectory(12_reduce) add_subdirectory(13_pool2d_fwd) add_subdirectory(14_gemm_xdl_requant_relu_requant) -add_subdirectory(17_convnd_bwd_data_xdl) add_subdirectory(15_grouped_gemm) add_subdirectory(16_gemm_reduce) +add_subdirectory(17_convnd_bwd_data_xdl) add_subdirectory(18_batched_gemm_reduce) +add_subdirectory(19_binary_elementwise) +add_subdirectory(20_convnd_bwd_weight_xdl) +add_subdirectory(21_gemm_layernorm) +add_subdirectory(22_cgemm) +add_subdirectory(23_softmax) diff --git a/include/ck/config.hpp b/include/ck/config.hpp index 6b0710a795..fbea8ecd40 100644 --- a/include/ck/config.hpp +++ b/include/ck/config.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + #ifndef CK_CONFIG_AMD_HPP #define CK_CONFIG_AMD_HPP @@ -76,6 +79,12 @@ #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 #endif +#if defined(__gfx90a__) // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 +#else +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 +#endif + // inline asm #define CK_USE_AMD_INLINE_ASM 1 @@ -91,10 +100,11 @@ // experimental feature: static tensor descriptor #define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0 -// experimental feature: buffer load/store/atomic-add OOB trick +// experimental feature: buffer load/store/atomic-add/ OOB trick #define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 // experimental feature: in-regsiter sub-dword transpose #define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1 @@ -109,6 +119,10 @@ // experimental feature: use __builtin_memcpy instead of union to do bit_cast #define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1 +// experimental feature: optimize for inter-wave scheduling policy +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 0 +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1 + // hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be // thread-invariant, otherwise it's a bug @@ -128,19 +142,29 @@ // tuning parameter #define CK_WORKAROUND_SWDEV_325164 1 -// workaround for verification failure ConvNd forward -// https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135 -#define CK_WORKAROUND_GITHUB_135 1 - namespace ck { enum struct InMemoryDataOperationEnum { Set, AtomicAdd, + AtomicMax, Add }; +template +struct InMemoryDataOperationEnumSequence +{ + static constexpr int mSize = sizeof...(Is); + + __host__ __device__ static constexpr InMemoryDataOperationEnum At(int I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + const InMemoryDataOperationEnum mData[mSize + 1] = {Is..., InMemoryDataOperationEnum::Set}; + return mData[I]; + } +}; + // TODO: no longer needed, remove this enum struct ActivTypeEnum { diff --git a/include/ck/hip_version.hpp.in b/include/ck/hip_version.hpp.in deleted file mode 100644 index 4290ef7e0d..0000000000 --- a/include/ck/hip_version.hpp.in +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -// "_PACKAGE_" to avoid name contentions: the macros like -// HIP_VERSION_MAJOR are defined in HIP_VERSION.h. -// clang-format off -#define CK_HIP_PACKAGE_VERSION_MAJOR @CK_HIP_VERSION_MAJOR@ -#define CK_HIP_PACKAGE_VERSION_MINOR @CK_HIP_VERSION_MINOR@ -#define CK_HIP_PACKAGE_VERSION_PATCH @CK_HIP_VERSION_PATCH@ -// clang-format on - -#ifndef CK_HIP_PACKAGE_VERSION_MAJOR -#define CK_HIP_PACKAGE_VERSION_MAJOR 0 -#endif -#ifndef CK_HIP_PACKAGE_VERSION_MINOR -#define CK_HIP_PACKAGE_VERSION_MINOR 0 -#endif -#ifndef CK_HIP_PACKAGE_VERSION_PATCH -#define CK_HIP_PACKAGE_VERSION_PATCH 0 -#endif -// 3 decimal digits for major and minor, 6 digits for patch number. -// Max number is 999,999,999999 == 0xE8,D4A5,0FFF that fits into 64-bit math. -#if CK_HIP_PACKAGE_VERSION_MAJOR > 999 || CK_HIP_PACKAGE_VERSION_MAJOR > 999 || \ - CK_HIP_PACKAGE_VERSION_PATCH > 999999 -#error "Too big HIP version number(s)" -#endif -#define CK_HIP_PACKAGE_VERSION_FLAT \ - ((CK_HIP_PACKAGE_VERSION_MAJOR * 1000ULL + CK_HIP_PACKAGE_VERSION_MINOR) * 1000000 + \ - CK_HIP_PACKAGE_VERSION_PATCH) diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp new file mode 100644 index 0000000000..74b20acecd --- /dev/null +++ b/include/ck/host_utility/device_prop.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include +#include + +namespace ck { + +inline std::string get_device_name() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return std::string(); + } + + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return std::string(); + } + const std::string raw_name(props.gcnArchName); + + // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + static std::map device_name_map = { + {"Ellesmere", "gfx803"}, + {"Baffin", "gfx803"}, + {"RacerX", "gfx803"}, + {"Polaris10", "gfx803"}, + {"Polaris11", "gfx803"}, + {"Tonga", "gfx803"}, + {"Fiji", "gfx803"}, + {"gfx800", "gfx803"}, + {"gfx802", "gfx803"}, + {"gfx804", "gfx803"}, + {"Vega10", "gfx900"}, + {"gfx901", "gfx900"}, + {"10.3.0 Sienna_Cichlid 18", "gfx1030"}, + }; + + const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. + + auto match = device_name_map.find(name); + if(match != device_name_map.end()) + return match->second; + return name; +} + +} // namespace ck diff --git a/include/ck/options.hpp b/include/ck/options.hpp new file mode 100644 index 0000000000..82c604f82b --- /dev/null +++ b/include/ck/options.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define CK_TIME_KERNEL 1 diff --git a/include/ck/stream_config.hpp b/include/ck/stream_config.hpp new file mode 100644 index 0000000000..3e80b4c892 --- /dev/null +++ b/include/ck/stream_config.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +struct StreamConfig +{ + hipStream_t stream_id_ = nullptr; + bool time_kernel_ = false; +}; diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 8787abd6ba..e62255ff48 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -136,7 +136,11 @@ struct TensorAdaptor using ElementSize = remove_cv_t; public: +#if 0 // workaround compiler complaint about constexpr __host__ __device__ constexpr TensorAdaptor() = default; +#else + __host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {} +#endif __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) : transforms_{transforms}, element_size_{InitializeElementSize(transforms)} diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 9cd51c61d6..0ca4f6e24d 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -111,7 +111,14 @@ struct TensorDescriptor using ElementSize = remove_cv_t; public: +#if 0 // workaround compiler complaint about constexpr __host__ __device__ constexpr TensorDescriptor() = default; +#else + __host__ __device__ constexpr TensorDescriptor() + : transforms_{}, element_size_{}, element_space_size_{} + { + } +#endif __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, ElementSpaceSize element_space_size) diff --git a/include/ck/tensor_description/tensor_descriptor_helper.hpp b/include/ck/tensor_description/tensor_descriptor_helper.hpp index ad75f9245e..ddc0ede404 100644 --- a/include/ck/tensor_description/tensor_descriptor_helper.hpp +++ b/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -1,6 +1,4 @@ -#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP -#define CK_TENSOR_DESCRIPTOR_HELPER_HPP - +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" #include "multi_index_transform_helper.hpp" @@ -35,6 +33,12 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt } #endif +// Lengths..., Strides... could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> template ::type = false> @@ -68,10 +72,10 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple{}, Number<1>{}); + const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{}); #else const auto element_space_size = - calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); + calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{}); #endif return TensorDescriptor, @@ -82,9 +86,12 @@ __host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> template __host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple& lengths) @@ -100,7 +107,7 @@ make_naive_tensor_descriptor_packed(const Tuple& lengths) constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; - const auto element_space_size = container_reduce(lengths, math::multiplies{}, Number<1>{}); + const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{}); return TensorDescriptor, remove_cv_t, @@ -110,6 +117,12 @@ make_naive_tensor_descriptor_packed(const Tuple& lengths) element_space_size}; } +// Lengths... could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// align could be: +// 1) index_t, or +// 2) Number<> template __host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple& lengths, Align align) @@ -146,4 +159,3 @@ make_naive_tensor_descriptor_aligned(const Tuple& lengths, Align ali } } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp similarity index 97% rename from include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp rename to include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp index 0a7b8486f4..f7fa867e16 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -1,10 +1,8 @@ -#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP -#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP - +#pragma once #include "common_header.hpp" #include "tensor_adaptor.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" -#include "threadwise_contraction_dlops.hpp" +#include "threadwise_tensor_slice_transfer_v4r1.hpp" +#include "threadwise_contraction_dl.hpp" namespace ck { @@ -41,7 +39,7 @@ template ::type = false> -struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 +struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 { using AIndex = MultiIndex<3>; using BIndex = MultiIndex<3>; @@ -148,7 +146,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); public: - __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + __device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( get_thread_local_1d_id())}, a_thread_copy_{ @@ -175,6 +173,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B "wrong!"); // TODO: remove this restriction + static_assert(BM0 == 2, "wrong"); static_assert(BM0 == 2 && BN0 == 2, "wrong"); } @@ -226,7 +225,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); constexpr auto threadwise_contraction = - ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< FloatA, FloatB, FloatC, @@ -407,4 +406,3 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 7986175702..b93d5ff839 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -3,10 +3,26 @@ #include "threadwise_tensor_slice_transfer.hpp" #include "xdlops_gemm.hpp" #include "tensor_adaptor.hpp" +#include "thread_group.hpp" namespace ck { -template {}; static constexpr auto I3 = Number<3>{}; + using ThisThreadBlock = ThisThreadBlock; + static constexpr index_t WaveSize = get_warp_size(); static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); @@ -53,7 +71,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __device__ static auto GetWaveIdx() { - const index_t thread_id = ThreadGroup::GetThreadId(); + const index_t thread_id = ThisThreadBlock::GetThreadId(); constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), @@ -120,8 +138,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 BK0NK1BlockDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - static_assert(ThreadGroup::GetNumOfThread() == MWaves * NWaves * WaveSize, - "ThreadGroup::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, "wrong!"); @@ -299,7 +317,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 }); } - private: + protected: // A[M0, M1, M2, KPerThread] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); @@ -336,4 +354,232 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; +// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro +// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in +// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the +// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 +template +struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + +#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING + using Base::a_block_desc_m0_m1_m2_k; + using Base::A_K1; + using Base::b_block_desc_n0_n1_n2_k; + using Base::B_K1; + using Base::c_thread_buf_; + using Base::c_thread_desc_; + using Base::CalculateAThreadOriginDataIndex; + using Base::CalculateBThreadOriginDataIndex; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + + // 2-wave optimized blockwise gemm + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, k), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, k), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except + // the first, as we can shorten non-MAC cluster a bit and there's no observable negative + // impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids + // some out-of-sync waves hijacking MAC resource from other workgroups and reducing the + // chance of latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. + if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) + { + asm volatile("s_barrier" ::); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from blockwise_gemm is + // moved here B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts It is performed near the end of MAC cluster to + // minimize lgkmcnt penalty + if constexpr(k.value == KPerThread - KPerInnerLoop && + k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && + n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + + // TODO: insert setprio in more precise manner since we + // could have more than >1 MFMA instructions in single call + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + + protected: + // A[M0, M1, M2, KPerInnerLoop] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + // B[N0, N1, N2, KPerInnerLoop] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; + +#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING +}; + +template +constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v4r1.hpp deleted file mode 100644 index 8306fa93ff..0000000000 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v4r1.hpp +++ /dev/null @@ -1,169 +0,0 @@ -#pragma once -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer_v3r1.hpp" - -namespace ck { - -// this version does following things to avoid scratch memory issue -// 1. Use StaticallyIndexedArray instead of C array for thread buffer -// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor -// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v4r1 -{ - static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - - static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; - - using Index = MultiIndex; - - __device__ constexpr BlockwiseTensorSliceTransfer_v4r1( - const SrcDesc& src_desc, - const Index& src_block_slice_origin, - const SrcElementwiseOperation& src_element_op, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin, - const DstElementwiseOperation& dst_element_op) - : threadwise_transfer_(src_desc, - make_zero_multi_index(), - src_element_op, - dst_desc, - make_zero_multi_index(), - dst_element_op) - - { - static_assert(nDim == remove_cvref_t::GetNumOfDimension() && - nDim == remove_cvref_t::GetNumOfDimension() && - nDim == ThreadClusterLengths::Size() && - nDim == ThreadClusterArrangeOrder::Size() && - nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), - "wrong! nDim not consistent"); - - static_assert( - is_same{}, - "wrong! threads should be mapped to cover entire slicing window"); - - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! BlockSize too small"); - - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); - - const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; - - threadwise_transfer_.SetSrcSliceOrigin(src_desc, - src_block_slice_origin + thread_data_idx_begin); - threadwise_transfer_.SetDstSliceOrigin(dst_desc, - dst_block_slice_origin + thread_data_idx_begin); - } - } - - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - Number thread_scratch_id = Number{}) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); - } - } - - template - __device__ void RunWrite(const DstDesc& dst_desc, - DstBuffer& dst_buf, - Number thread_scratch_id = Number{}) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); - } - } - - template - __device__ void Run(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const DstDesc& dst_desc, - DstBuffer& dst_buf, - Number thread_scratch_id) - { - RunRead(src_desc, src_buf, thread_scratch_id); - RunWrite(dst_desc, dst_buf, thread_scratch_id); - } - - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); - } - } - - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); - } - } - - private: - static constexpr auto thread_cluster_desc_ = - make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); - - using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v3r1; - - ThreadwiseTransfer threadwise_transfer_; -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp index 93fe5da723..e8ec164364 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -75,14 +75,13 @@ struct BlockwiseTensorSliceTransfer_v5r1 } } - template - __device__ void - RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) { if(BlockSize == thread_cluster_desc_.GetElementSize() or get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks); + threadwise_transfer_.RunRead(src_desc, src_buf); } } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r1.hpp deleted file mode 100644 index 3bf86d6785..0000000000 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r1.hpp +++ /dev/null @@ -1,133 +0,0 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R1_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_tensor_slice_transfer_v6r1.hpp" - -namespace ck { - -// this version does following things to avoid scratch memory issue -// 1. Use StaticallyIndexedArray instead of C array for thread buffer -// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor -// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v6r1 -{ - static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - - static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; - - using Index = MultiIndex; - - __device__ constexpr BlockwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, - const Index& src_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin, - const ElementwiseOperation& element_op) - : threadwise_transfer_(src_desc, - make_zero_multi_index(), - dst_desc, - make_zero_multi_index(), - element_op) - - { - static_assert(nDim == remove_cvref_t::GetNumOfDimension() && - nDim == remove_cvref_t::GetNumOfDimension() && - nDim == ThreadClusterLengths::Size() && - nDim == ThreadClusterArrangeOrder::Size() && - nDim == DimAccessOrder::Size(), - "wrong! nDim not consistent"); - - static_assert( - is_same{}, - "wrong! threads should be mapped to cover entire slicing window"); - - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! BlockSize too small"); - - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); - - const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; - - threadwise_transfer_.SetSrcSliceOrigin(src_desc, - src_block_slice_origin + thread_data_idx_begin); - threadwise_transfer_.SetDstSliceOrigin(dst_desc, - dst_block_slice_origin + thread_data_idx_begin); - } - } - - template - __device__ void Run(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - const DstDesc& dst_desc, - DstBuffer& dst_buf) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf); - } - } - - __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); - } - } - - __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) - { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) - { - threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); - } - } - - private: - static constexpr auto thread_cluster_desc_ = - make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); - - using ThreadwiseTransfer = - ThreadwiseTensorSliceTransfer_v6r1; - - ThreadwiseTransfer threadwise_transfer_; -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp index cc452b5e5c..8580b9ea4a 100644 --- a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp +++ b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp @@ -45,7 +45,9 @@ template + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithNanCheck> struct PartitionedBlockwiseReduction { static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), @@ -62,8 +64,6 @@ struct PartitionedBlockwiseReduction static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - using Accumulation = detail::AccumulateWithNanCheck; - template __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value) { @@ -113,13 +113,16 @@ struct PartitionedBlockwiseReduction // 3) in_out_value/in_out_index is the input data in vgpr from each thread // 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread // clang-format on -template +template < + typename AccDataType, + typename IndexDataType, + index_t BlockSize, + typename ThreadClusterLengths_M_K, + typename ThreadClusterArrangeOrder, + typename OpReduce, + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithIndexAndNanCheck> struct PartitionedBlockwiseReductionWithIndex { static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), @@ -136,9 +139,6 @@ struct PartitionedBlockwiseReductionWithIndex static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - using Accumulation = - detail::AccumulateWithIndexAndNanCheck; - // This interface accumulates on both data values and indices template __device__ static void Reduce(BufferType& work_val_buffer, diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp index 0baebc5e3b..1f0ad3e35a 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -14,7 +14,7 @@ namespace ck { template ::GetNumOfDimension(); - static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; using Index = MultiIndex; @@ -54,7 +54,7 @@ struct ThreadGroupTensorSliceTransfer_v6r1 "wrong! nDim not consistent"); static_assert( - is_same{}, + is_same{}, "wrong! threads should be mapped to cover entire slicing window"); static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp similarity index 71% rename from include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r2.hpp rename to include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp index 575a015802..121ddf12ad 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -1,6 +1,4 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R2_HPP - +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" @@ -13,10 +11,10 @@ namespace ck { // 1. Use StaticallyIndexedArray instead of C array for thread buffer // 2. It does not keep reference to tensor descriptor // 3. Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v6r2 +struct ThreadGroupTensorSliceTransfer_v6r2 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, - const Index& src0_block_slice_origin, - const Src1Desc& src1_desc, - const Index& src1_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin, - const ElementwiseOperation& element_op) + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) : threadwise_transfer_(src0_desc, make_zero_multi_index(), src1_desc, @@ -64,17 +62,17 @@ struct BlockwiseTensorSliceTransfer_v6r2 "wrong! nDim not consistent"); static_assert( - is_same{}, + is_same{}, "wrong! threads should be mapped to cover entire slicing window"); - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! BlockSize too small"); + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( - make_multi_index(get_thread_local_1d_id())); + make_multi_index(ThreadGroup::GetThreadId())); const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; @@ -95,8 +93,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 const DstDesc& dst_desc, DstBuffer& dst_buf) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf); } @@ -104,8 +102,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); } @@ -113,8 +111,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); } @@ -122,8 +120,8 @@ struct BlockwiseTensorSliceTransfer_v6r2 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); } @@ -154,4 +152,3 @@ struct BlockwiseTensorSliceTransfer_v6r2 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp similarity index 72% rename from include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r3.hpp rename to include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp index 4a1d82000a..ca5db90f30 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -1,6 +1,4 @@ -#ifndef CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP -#define CK_BLOCKWISE_TENSOR_SLICE_TRANSFER_V6R3_HPP - +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" @@ -13,10 +11,10 @@ namespace ck { // 1. Use StaticallyIndexedArray instead of C array for thread buffer // 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor // 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate -template -struct BlockwiseTensorSliceTransfer_v6r3 +struct ThreadGroupTensorSliceTransfer_v6r3 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); - static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; using Index = MultiIndex; - __device__ constexpr BlockwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, - const Index& src0_block_slice_origin, - const Src1Desc& src1_desc, - const Index& src1_block_slice_origin, - const Src2Desc& src2_desc, - const Index& src2_block_slice_origin, - const DstDesc& dst_desc, - const Index& dst_block_slice_origin, - const ElementwiseOperation& element_op) + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) : threadwise_transfer_(src0_desc, make_zero_multi_index(), src1_desc, @@ -72,14 +70,14 @@ struct BlockwiseTensorSliceTransfer_v6r3 "wrong! nDim not consistent"); static_assert( - is_same{}, + is_same{}, "wrong! threads should be mapped to cover entire slicing window"); - static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), - "wrong! BlockSize too small"); + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( make_multi_index(get_thread_local_1d_id())); @@ -107,8 +105,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 const DstDesc& dst_desc, DstBuffer& dst_buf) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.Run( src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf); @@ -117,8 +115,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); } @@ -126,8 +124,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); } @@ -135,8 +133,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step); } @@ -144,8 +142,8 @@ struct BlockwiseTensorSliceTransfer_v6r3 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) { - if(BlockSize == thread_cluster_desc_.GetElementSize() or - get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); } @@ -179,4 +177,3 @@ struct BlockwiseTensorSliceTransfer_v6r3 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp new file mode 100644 index 0000000000..d499eee4c5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp @@ -0,0 +1,169 @@ +#pragma once + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_tensor_slice_transfer_v7.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename DimAccessOrder, + index_t VectorDim, + index_t ScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags> +struct ThreadGroupTensorSliceTransfer_v7 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs); + } + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp new file mode 100644 index 0000000000..60995e068c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -0,0 +1,17 @@ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionBackwardWeightSpecialization +{ + Default, + Filter1x1Stride1Pad0, + Filter1x1Pad0, + OddC, +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp new file mode 100644 index 0000000000..c093f5028c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp @@ -0,0 +1,332 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "gridwise_5ary_Elementwise_1d.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct Device5AryElementwise : public BaseOperator +{ + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::vector& lengths, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(NDim > 1) + { + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + else + return PadDescriptor_M_1d(desc, gridSize, blockSize); + } + + using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using DGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using EGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using FGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + + using Gridwise5AryEltwise = Gridwise5AryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + const BDataType* p_b, + const CDataType* p_c, + const DDataType* p_d, + const EDataType* p_e, + FDataType* p_f, + const std::vector& lengths, + const std::vector& a_strides, + const std::vector& b_strides, + const std::vector& c_strides, + const std::vector& d_strides, + const std::vector& e_strides, + const std::vector& f_strides, + ElementwiseFunctor functor) + : p_a_(p_a), + p_b_(p_b), + p_c_(p_c), + p_d_(p_d), + p_e_(p_e), + p_f_(p_f), + lengths_(lengths), + a_strides_(a_strides), + b_strides_(b_strides), + c_strides_(c_strides), + d_strides_(d_strides), + e_strides_(e_strides), + f_strides_(f_strides), + functor_(functor), + blockSize_(256), + gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future + { + a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_); + b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_); + c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_); + d_grid_desc_m_ = MakeDescriptor_M(lengths, d_strides, gridSize_, blockSize_); + e_grid_desc_m_ = MakeDescriptor_M(lengths, e_strides, gridSize_, blockSize_); + f_grid_desc_m_ = MakeDescriptor_M(lengths, f_strides, gridSize_, blockSize_); + } + + const ADataType* p_a_; + const BDataType* p_b_; + const CDataType* p_c_; + const DDataType* p_d_; + const EDataType* p_e_; + FDataType* p_f_; + std::vector lengths_; + AGridDesc_M a_grid_desc_m_; + BGridDesc_M b_grid_desc_m_; + CGridDesc_M c_grid_desc_m_; + DGridDesc_M d_grid_desc_m_; + EGridDesc_M e_grid_desc_m_; + FGridDesc_M f_grid_desc_m_; + std::vector a_strides_; + std::vector b_strides_; + std::vector c_strides_; + std::vector d_strides_; + std::vector e_strides_; + std::vector f_strides_; + ElementwiseFunctor functor_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_5ary_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.p_a_, + arg.p_b_, + arg.p_c_, + arg.p_d_, + arg.p_e_, + arg.p_f_, + arg.a_grid_desc_m_, + arg.b_grid_desc_m_, + arg.c_grid_desc_m_, + arg.d_grid_desc_m_, + arg.e_grid_desc_m_, + arg.f_grid_desc_m_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument& p_arg) { return IsSupportedArgument(&p_arg); } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + if(pArg->lengths_.size() != NDim) + return false; + + if(pArg->lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) { + bool ret = true; + + if(!isLastDimensionCoalesced) + ret = scalarPerVector == 1; + else + ret = MPerThread % scalarPerVector == 0; + + return ret; + }; + + if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->d_strides_.back() == 1, DScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->e_strides_.back() == 1, EScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->f_strides_.back() == 1, FScalarPerVector)) + return false; + + return true; + }; + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const CDataType* p_c, + const DDataType* p_d, + const EDataType* p_e, + FDataType* p_f, + std::vector lengths, + std::vector a_strides, + std::vector b_strides, + std::vector c_strides, + std::vector d_strides, + std::vector e_strides, + std::vector f_strides, + ElementwiseFunctor functor) + { + return Argument{p_a, + p_b, + p_c, + p_d, + p_e, + p_f, + lengths, + a_strides, + b_strides, + c_strides, + d_strides, + e_strides, + f_strides, + functor}; + } + + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_c, + const void* p_d, + const void* p_e, + void* p_f, + std::vector lengths, + std::vector a_strides, + std::vector b_strides, + std::vector c_strides, + std::vector d_strides, + std::vector e_strides, + std::vector f_strides, + ElementwiseFunctor functor) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_d), + static_cast(p_e), + static_cast(p_f), + lengths, + a_strides, + b_strides, + c_strides, + d_strides, + e_strides, + f_strides, + functor); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() { return std::make_unique(); } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index cf48695ad0..809eba5578 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -1,8 +1,9 @@ -#ifndef DEVICE_BASE_HPP -#define DEVICE_BASE_HPP +#pragma once #include +#include "stream_config.hpp" + namespace ck { namespace tensor_operation { namespace device { @@ -14,6 +15,8 @@ struct BaseArgument BaseArgument& operator=(const BaseArgument&) = default; virtual ~BaseArgument() {} + + void* p_workspace_ = nullptr; }; struct BaseInvoker @@ -22,7 +25,10 @@ struct BaseInvoker BaseInvoker(const BaseInvoker&) = default; BaseInvoker& operator=(const BaseInvoker&) = default; - virtual float Run(const BaseArgument*, int = 1) = 0; + virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}) + { + return float{0}; + } virtual ~BaseInvoker() {} }; @@ -33,8 +39,16 @@ struct BaseOperator BaseOperator(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default; - virtual bool IsSupportedArgument(const BaseArgument*) = 0; - virtual std::string GetTypeString() const = 0; + virtual bool IsSupportedArgument(const BaseArgument*) { return false; } + virtual std::string GetTypeString() const { return ""; } + + virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } + + virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const + { + assert(p_arg); + p_arg->p_workspace_ = p_workspace; + } virtual ~BaseOperator() {} }; @@ -42,4 +56,3 @@ struct BaseOperator } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp index a90bc44fdf..2379719fb9 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -17,11 +17,12 @@ namespace device { template (compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); - const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetD0BasePtr(g_idx))); - const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_base_ptr_of_batch_.GetD1BasePtr(g_idx))); + static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { + const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); + p_ds_grid(In) = p_ds_grid(In) + d_batch_offset; + }); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, p_c_grid + c_batch_offset, - p_d0_grid + d0_batch_offset, - p_d1_grid + d1_batch_offset, + p_ds_grid, p_shared, a_element_op, b_element_op, c_element_op, - d1_element_op, + dxs_in_element_op, + dxs_out_element_op, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, c_grid_desc_mblock_mperblock_nblock_nperblock, @@ -90,13 +92,13 @@ __global__ void ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; - ignore = p_d0_grid; - ignore = p_d1_grid; + ignore = p_ds_grid; ignore = batch_count; ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; - ignore = d1_element_op; + ignore = dxs_in_element_op; + ignore = dxs_out_element_op; ignore = a_grid_desc_ak0_m_ak1; ignore = b_grid_desc_bk0_n_bk1; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; @@ -106,6 +108,9 @@ __global__ void #endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__)) } +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. template -struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce + index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct DeviceBatchedGemmReduce_Xdl_CShuffle + : public DeviceGemmReduce { using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; @@ -461,56 +470,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_insert_transform(batch_count), - make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); - - const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto globalblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return globalblockid_to_m0_n0_block_cluster_adaptor; - } - struct ComputeBasePtrOfStridedBatch { ComputeBasePtrOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, - index_t BatchStrideD0, - index_t BatchStrideD1) + index_t BatchStrideD) : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC), - BatchStrideD0_(BatchStrideD0), - BatchStrideD1_(BatchStrideD1) + BatchStrideD_(BatchStrideD) { } @@ -529,22 +498,20 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(BatchStrideC_); } - __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const + template + __host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx, + Number reduction_idx) const { - return g_idx * static_cast(BatchStrideD0_); - } - - __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideD1_); + // TODO - Support sequence of StrideD in MakeArgument() + (void)reduction_idx; + return g_idx * static_cast(BatchStrideD_); } private: index_t BatchStrideA_; index_t BatchStrideB_; index_t BatchStrideC_; - index_t BatchStrideD0_; - index_t BatchStrideD1_; + index_t BatchStrideD_; }; // GridwiseGemm @@ -554,15 +521,15 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce; - - using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1)); + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; // Argument struct Argument : public BaseArgument @@ -610,8 +576,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), + type_convert(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), + type_convert(c_grid_desc_m_n_.GetElementSpaceSize()), + type_convert(d_grid_desc_m_.GetElementSpaceSize())}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}, - d1_element_op_{d1_element_op} + dxs_in_element_op_{dxs_in_element_op}, + dxs_out_element_op_{dxs_out_element_op} { - if(GridwiseGemm::CheckValidity( - a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_)) + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -655,8 +623,6 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, + typename GridwiseGemm::DefaultBlock2CTileMap, true>; - launch_kernel(kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_d0_grid_, - arg.p_d1_grid_, - arg.BatchCount_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.d1_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.d_grid_desc_mblock_mperblock_, - arg.compute_base_ptr_of_batch_, - arg.block_2_ctile_map_); + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_ds_grid_, + arg.BatchCount_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); } else { @@ -769,48 +741,52 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce, + typename GridwiseGemm::DefaultBlock2CTileMap, false>; - launch_kernel(kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_d0_grid_, - arg.p_d1_grid_, - arg.BatchCount_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.d1_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.d_grid_desc_mblock_mperblock_, - arg.compute_base_ptr_of_batch_, - arg.block_2_ctile_map_); + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_ds_grid_, + arg.BatchCount_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); } - return 0; + return elapsed_time; } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -822,8 +798,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - void* p_d0, - void* p_d1, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - D1ElementwiseOperation d1_element_op, - index_t BatchCount) override + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + void* p_dxs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsReduceAccElementwiseOperation dxs_out_element_op, + index_t BatchCount) override { + DPtrsGlobal dxs_tuple = *(static_cast(p_dxs)); return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c), - static_cast(p_d0), - static_cast(p_d1), + dxs_tuple, MRaw, NRaw, KRaw, @@ -909,7 +888,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_insert_transform(batch_count), - make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); - - const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto globalblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return globalblockid_to_m0_n0_block_cluster_adaptor; - } - struct ComputePtrOffsetOfStridedBatch { ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, @@ -354,7 +316,7 @@ struct DeviceBatchedGemmXdl using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1)); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; // Argument struct Argument : public BaseArgument @@ -384,23 +346,25 @@ struct DeviceBatchedGemmXdl DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)}, c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(), - b_grid_desc_k0_n_k1_.GetElementSpaceSize(), - c_grid_desc_m_n_.GetElementSpaceSize()}, - block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{ + type_convert(a_grid_desc_k0_m_k1_.GetElementSpaceSize()), + type_convert(b_grid_desc_k0_n_k1_.GetElementSpaceSize()), + type_convert(c_grid_desc_m_n_.GetElementSpaceSize())}, + block_2_ctile_map_{ + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, M01_{M01}, N01_{N01}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} { - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - - block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount, c_grid_desc_m_n_, M01, N01); } } @@ -427,7 +391,7 @@ struct DeviceBatchedGemmXdl { using Argument = DeviceBatchedGemmXdl::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) @@ -445,15 +409,14 @@ struct DeviceBatchedGemmXdl if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); } const index_t grid_size = - GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -476,8 +439,8 @@ struct DeviceBatchedGemmXdl remove_reference_t, true>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -510,8 +473,8 @@ struct DeviceBatchedGemmXdl remove_reference_t, false>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -533,9 +496,10 @@ struct DeviceBatchedGemmXdl } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -550,8 +514,7 @@ struct DeviceBatchedGemmXdl return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp new file mode 100644 index 0000000000..34b3a59c74 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp @@ -0,0 +1,234 @@ +#pragma once +#include +#include + +#include "device.hpp" +#include "device_base.hpp" +#include "gridwise_binary_elementwise_1d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBinaryElementwise : public BaseOperator +{ + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + const auto M = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(M, loop_step) - M; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(M, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::vector& lengths, + const std::vector& strides, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(NDim > 1) + { + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + else + return PadDescriptor_M_1d(desc, gridSize, blockSize); + } + + using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using GridwiseBinEltwise = GridwiseBinaryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const std::vector& lengths, + const std::vector& a_strides, + const std::vector& b_strides, + const std::vector& c_strides, + ElementwiseFunctor functor) + : p_a_(p_a), + p_b_(p_b), + p_c_(p_c), + lengths_(lengths), + a_strides_(a_strides), + b_strides_(b_strides), + c_strides_(c_strides), + functor_(functor), + blockSize_(256), + gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future + { + a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_); + b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_); + c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_); + } + + const ADataType* p_a_; + const BDataType* p_b_; + CDataType* p_c_; + std::vector lengths_; + AGridDesc_M a_grid_desc_m_; + BGridDesc_M b_grid_desc_m_; + CGridDesc_M c_grid_desc_m_; + std::vector a_strides_; + std::vector b_strides_; + std::vector c_strides_; + ElementwiseFunctor functor_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_binary_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.p_a_, + arg.p_b_, + arg.p_c_, + arg.a_grid_desc_m_, + arg.b_grid_desc_m_, + arg.c_grid_desc_m_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + if(pArg->lengths_.size() != NDim) + return false; + + if(pArg->lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) { + bool ret = true; + + if(!isLastDimensionCoalesced) + ret = scalarPerVector == 1; + else + ret = MPerThread % scalarPerVector == 0; + + return ret; + }; + + if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector)) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + std::vector lengths, + std::vector a_strides, + std::vector b_strides, + std::vector c_strides, + ElementwiseFunctor functor) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + lengths, + a_strides, + b_strides, + c_strides, + functor); + } + + std::unique_ptr MakeInvokerPointer() { return std::make_unique(); } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBinaryElementwise" + << "<" + << "MPerThread = " << MPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp new file mode 100644 index 0000000000..ad4fde750f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -0,0 +1,73 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceCGemm : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a_real, + const void* p_a_imag, + const void* p_b_real, + const void* p_b_imag, + void* p_c_real, + void* p_c_imag, + void* p_workspace, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + virtual std::size_t GetWorkspaceSize(index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC) = 0; +}; + +template +using DeviceCGemmPtr = std::unique_ptr< + DeviceCGemm>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp new file mode 100644 index 0000000000..df2805b886 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -0,0 +1,972 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once +#include +#include +#include "device.hpp" +#include "device_gemm.hpp" +#include "device_cgemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "binary_element_wise_operation.hpp" +#include "gridwise_binary_elementwise_1d.hpp" +#include "tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ALayout, + typename BLayout, + typename CLayout, + typename ADataType, + typename BDataType, + typename CDataType, + typename GemmAccDataType, + typename CShuffleDataType, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t NumGemmKPrefetchStage, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t AK1, + index_t BK1, + index_t MPerXDL, + index_t NPerXDL, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_AK1, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_BK1, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler(), + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceCGemm_4Gemm_Xdl_CShuffle + : public DeviceCGemm +{ + using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto MPerThread = Number<4>{}; + static constexpr auto AScalarPerVector = Number<4>{}; + static constexpr auto BScalarPerVector = Number<4>{}; + static constexpr auto CScalarPerVector = Number<4>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + const auto M = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(M, loop_step) - M; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(M, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::vector& lengths, + const std::vector& strides, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{}); + auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<2>{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid_real, + const ADataType* p_a_grid_imag, + const BDataType* p_b_grid_real, + const BDataType* p_b_grid_imag, + CDataType* p_c_grid_real, + CDataType* p_c_grid_imag, + CDataType* p_workspace, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_real_{p_a_grid_real}, + p_a_grid_imag_{p_a_grid_imag}, + p_b_grid_real_{p_b_grid_real}, + p_b_grid_imag_{p_b_grid_imag}, + p_c_grid_real_{p_c_grid_real}, + p_c_grid_imag_{p_c_grid_imag}, + p_aux_grid_{p_workspace}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + + const index_t grid_size = block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_); + + if constexpr(is_same::value) + { + c_grid_desc_m_ = + DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {StrideC, I1}, grid_size, BlockSize); + } + else if constexpr(is_same::value) + { + c_grid_desc_m_ = + DeviceOp::MakeDescriptor_M({MRaw, NRaw}, {I1, StrideC}, grid_size, BlockSize); + } + + p_aux_2_grid_ = p_workspace + c_grid_desc_m_n_.GetElementSpaceSize(); + } + + // private: + const ADataType* p_a_grid_real_; + const ADataType* p_a_grid_imag_; + const BDataType* p_b_grid_real_; + const BDataType* p_b_grid_imag_; + CDataType* p_c_grid_real_; + CDataType* p_c_grid_imag_; + CDataType* p_aux_grid_; + CDataType* p_aux_2_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M c_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + using Add = ck::tensor_operation::element_wise::Add; + using Subtract = ck::tensor_operation::element_wise::Subtract; + using GridwiseBinAdd = GridwiseBinaryElementwise_1D; + using GridwiseBinSubtract = GridwiseBinaryElementwise_1D; + const auto add_kernel = kernel_binary_elementwise_1d; + const auto subtract_kernel = kernel_binary_elementwise_1d; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_real_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_imag_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel(stream_config, + subtract_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_real_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + Subtract{}); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_imag_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_real_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel(stream_config, + add_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_imag_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + Add{}); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_real_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_imag_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel(stream_config, + subtract_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_real_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + Subtract{}); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_real_, + arg.p_b_grid_imag_, + arg.p_aux_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_imag_, + arg.p_b_grid_real_, + arg.p_aux_2_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); + + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel(stream_config, + add_kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_aux_grid_, + arg.p_aux_2_grid_, + arg.p_c_grid_imag_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + arg.c_grid_desc_m_, + Add{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a_real, + const ADataType* p_a_imag, + const BDataType* p_b_real, + const BDataType* p_b_imag, + CDataType* p_c_real, + CDataType* p_c_imag, + CDataType* p_workspace, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a_real, + p_a_imag, + p_b_real, + p_b_imag, + p_c_real, + p_c_imag, + p_workspace, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a_real, + const void* p_a_imag, + const void* p_b_real, + const void* p_b_imag, + void* p_c_real, + void* p_c_imag, + void* p_workspace, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a_real), + static_cast(p_a_imag), + static_cast(p_b_real), + static_cast(p_b_imag), + static_cast(p_c_real), + static_cast(p_c_imag), + static_cast(p_workspace), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceCGemm_4Gemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } + + std::size_t GetWorkspaceSize(index_t MRaw, + index_t NRaw, + [[maybe_unused]] index_t KRaw, + [[maybe_unused]] index_t StrideA, + [[maybe_unused]] index_t StrideB, + index_t StrideC) override + { + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC); + + return 2 * sizeof(CDataType) * c_grid_desc_m_n.GetElementSpaceSize(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 466e6ad89f..8404f4c266 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -11,7 +11,7 @@ #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r4r2.hpp" +#include "gridwise_gemm_xdlops_bwd_weight.hpp" namespace ck { namespace tensor_operation { @@ -81,6 +81,22 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ static constexpr auto K1Number = Number{}; static constexpr auto GemmK1Number = K1Number; + static constexpr auto N1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, @@ -125,27 +141,51 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const index_t N0 = N / N1Number; + const index_t GemmK0Total = N0 * Ho * Wo; + + const index_t GemmK0S = + math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock; + const index_t GemmK0Pad = GemmKBatch * GemmK0S; + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K)); + + const auto out_n0_ho_wo_k_n1_grid_desc = + transform_tensor_descriptor(out_n_ho_wo_k_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + make_pass_through_transform(Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0total_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)), + make_pass_through_transform(K), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmk0total_gemmm_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + out_gemmk0pad_gemmm_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); // B: input tensor const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( @@ -167,26 +207,50 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_gemmktotal_gemmn_grid_desc = + const auto in_n0_y_ho_x_wo_c_n1_grid_desc = transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + make_pass_through_transform(Y), + make_pass_through_transform(Ho), + make_pass_through_transform(X), + make_pass_through_transform(Wo), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 6>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_n0_y_ho_x_wo_c_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C)), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0total_gemmn_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + in_gemmk0pad_gemmn_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); // C: weight tensor const auto wei_gemmm_gemmn_grid_desc = @@ -205,7 +269,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ using CGridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -233,6 +297,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ABlockTransferDstScalarPerVector_K1, false, // AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, @@ -241,12 +308,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ BBlockTransferDstScalarPerVector_K1, false, // BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferScalarPerVector_NWaveNPerXdl, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; - using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -274,6 +346,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ ABlockTransferDstScalarPerVector_K1, false, // AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, @@ -282,10 +357,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ BBlockTransferDstScalarPerVector_K1, false, // BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CBlockTransferScalarPerVector_NWaveNPerXdl, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); @@ -353,17 +433,16 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, b_grid_desc_kbatch_k0_n_k1_, c_grid_desc_m_n_, - M01_, - N01_)) + block_2_ctile_map_)) { c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); } } @@ -415,20 +494,21 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { ShowInfo(arg); + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( - "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting"); } - const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch); + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); @@ -437,56 +517,35 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ float ave_time = 0; const auto Run = [&](const auto& kernel) { - if(nrepeat > 0) - { - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); - if(kbatch > 1 || nrepeat <= 0) - { - hipGetErrorString(hipMemset( - arg.p_c_grid_, - 0, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * - sizeof(CDataType))); - - launch_kernel(kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); }; if(has_main_k0_block_loop) { if(kbatch == 1) { - const auto kernel = kernel_gemm_xdlops_v2r4r2< + const auto kernel = kernel_gemm_xdlops_bwd_weight< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, @@ -503,7 +562,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ } else { - const auto kernel = kernel_gemm_xdlops_v2r4r2< + const auto kernel = kernel_gemm_xdlops_bwd_weight< GridwiseGemmAtomicAdd, ADataType, // TODO: distiguish A/B datatype CDataType, @@ -523,7 +582,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ { if(kbatch == 1) { - const auto kernel = kernel_gemm_xdlops_v2r4r2< + const auto kernel = kernel_gemm_xdlops_bwd_weight< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, @@ -540,7 +599,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ } else { - const auto kernel = kernel_gemm_xdlops_v2r4r2< + const auto kernel = kernel_gemm_xdlops_bwd_weight< GridwiseGemmAtomicAdd, ADataType, // TODO: distiguish A/B datatype CDataType, @@ -560,9 +619,10 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ return ave_time; } - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -582,6 +642,12 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ return false; } + // unmerge N to N0 and N1, where N1 equals to K1 + if(!(arg.Conv_N_ % K1 == 0)) + { + return false; + } + // vector store C matrix into global memory if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) { @@ -592,8 +658,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 5606dad034..83953e59bd 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -486,13 +486,16 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_)) + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - block_2_ctile_map_container_.push_back( - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01)); + block_2_ctile_map_container_.push_back(block_2_ctile_map); } } } @@ -531,7 +534,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) @@ -572,15 +575,14 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], arg.c_grid_desc_m_n_container_[i], - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_container_[i])) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); } - const index_t grid_size = - GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.c_grid_desc_m_n_container_[i]); const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); @@ -602,8 +604,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K true>; ave_time += launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -635,8 +637,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K false>; ave_time += launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -655,9 +657,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return ave_time; } - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -697,13 +700,12 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // Gridwise GEMM size - for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], arg.c_grid_desc_m_n_container_[i], - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_container_[i])) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index 6648929cd5..cc1c2cb2ca 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -460,6 +460,8 @@ struct using C0GridDesc_M_N = remove_cvref_t; using C1GridDesc_M_N = remove_cvref_t; + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< BlockSize, @@ -522,8 +524,6 @@ struct std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) @@ -541,8 +541,6 @@ struct c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, out_element_op_{out_element_op}, @@ -575,8 +573,12 @@ struct c0_grid_desc_m_n_ = descs[I3]; c1_grid_desc_m_n_ = descs[I4]; - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = GridwiseGemm:: @@ -592,9 +594,6 @@ struct GridwiseGemm:: MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( c1_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } } @@ -618,9 +617,7 @@ struct typename GridwiseGemm:: C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; + Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; @@ -642,7 +639,7 @@ struct { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -689,14 +686,14 @@ struct if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -723,12 +720,12 @@ struct InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, true>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -767,12 +764,12 @@ struct InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, false>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -795,9 +792,10 @@ struct return ave_time; } - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -851,8 +849,7 @@ struct return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -894,8 +891,6 @@ struct conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op}; @@ -938,8 +933,6 @@ struct conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op); diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index fd0941420c..a397b5e2b1 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -548,9 +548,13 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; c0_grid_desc_m_n_ = descs[I3]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = GridwiseGemm:: @@ -561,9 +565,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X GridwiseGemm:: MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( c0_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } } @@ -605,7 +606,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -649,14 +650,14 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -684,8 +685,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X true>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -723,8 +724,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X false>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -745,9 +746,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X return ave_time; } - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -801,8 +803,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index b508606a75..707413dfd3 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -417,6 +417,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< BlockSize, @@ -477,8 +479,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) @@ -490,8 +490,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W c_grid_desc_m_n_{}, c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, out_element_op_{out_element_op}, @@ -522,16 +520,17 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = GridwiseGemm:: MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } } @@ -544,9 +543,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W typename GridwiseGemm:: CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; + Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; @@ -568,7 +565,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -631,14 +628,14 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -659,12 +656,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, true>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -693,12 +690,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, false>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -717,9 +714,10 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W return ave_time; } - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -773,8 +771,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -812,8 +809,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op}; @@ -852,8 +847,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op); diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 3574f7667e..ece18459a0 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -408,15 +408,16 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K a_grid_desc_k0_m_k1_ = descs[I0]; b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } } @@ -450,7 +451,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -469,14 +470,14 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -498,8 +499,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K remove_reference_t, true>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -529,8 +530,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K remove_reference_t, false>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -549,9 +550,10 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return ave_time; } - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -605,8 +607,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } bool IsSupportedArgument(const BaseArgument* p_arg) override diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp index c3ebe58865..b1eea0b33f 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -4,7 +4,7 @@ #include #include #include -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "device.hpp" #include "device_conv_fwd.hpp" #include "common_header.hpp" @@ -92,7 +92,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { const auto naive_conv3d_fwd = ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk; - float ave_time = launch_and_time_kernel(naive_conv3d_fwd, - nrepeat, + float ave_time = launch_and_time_kernel(stream_config, + naive_conv3d_fwd, dim3(256), dim3(256), 0, @@ -137,9 +137,10 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index ff30a6880d..256d0f81e9 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -259,50 +259,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - struct Block2CTileMapMaker - { - Block2CTileMapMaker(index_t num_batches) : num_batches_(num_batches) {} - - __host__ __device__ constexpr auto - MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_insert_transform(num_batches_), - make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); - - const auto globalblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(num_batches_, M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto globalblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - globalblockid_to_g_m00_m01_n00_n01_block_cluster_adaptor); - - return globalblockid_to_m0_n0_block_cluster_adaptor; - } - - private: - index_t num_batches_; - }; - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, InDataType, @@ -345,8 +301,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using Block2CTileMap = - decltype(Block2CTileMapMaker{1}.MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; // Argument struct Argument : public BaseArgument @@ -398,18 +353,20 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + a_batch_stride_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); b_batch_stride_ = 0; c_batch_stride_ = c_grid_desc_m_n_.GetElementSpaceSize(); - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - - block_2_ctile_map_ = Block2CTileMapMaker{num_subbatches_}.MakeBlock2CTileMap( - c_grid_desc_m_n_, M01, N01); } } @@ -438,7 +395,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { { std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; @@ -457,16 +414,15 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - // todo: grid_size times arg.num_subbatches_ const index_t grid_size = - GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.num_subbatches_; + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * + arg.num_subbatches_; const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); @@ -487,8 +443,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ OutElementwiseOperation, remove_reference_t, true>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -522,8 +478,8 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ remove_reference_t, false>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -547,9 +503,10 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -564,8 +521,7 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..2991526851 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,1392 @@ +#pragma once + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_backward_weight.hpp" +#include "convolution_backward_weight_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_bwd_weight.hpp" +#include "gridwise_unary_elementwise_1d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdWeight +{ + using DeviceOp = + DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1); + } + + // type convert descs + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * blockSize * 4; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + template + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(Dim > 1) + { + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + } + else + return PadDescriptor_M0_1d(desc, gridSize, blockSize); + } + + using TypeConvertFp32ToBf16Functor = + ck::tensor_operation::element_wise::UnaryTypeConvert; + using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1)); + using GridwiseUEltwise = GridwiseUnaryElementwise_1D; + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + using GridwiseGemmAtomicAddFloatBf16Splitk = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + AccDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + output_spatial_lengths_{output_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation a_element_op_; + OutElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector output_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + ShowInfo(arg); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + float ave_time = 0; + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + const auto run_conv = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + // run kernel for bf16 with splitk + const auto run_bf16_splitk = [&](const auto& kernel) { + hipGetErrorString(hipMemset( + arg.p_workspace_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(AccDataType))); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + static_cast(arg.p_workspace_), + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + // kernel for type conversion + std::vector filter_dims{static_cast(arg.Conv_K_), + static_cast(arg.Conv_C_)}; + + filter_dims.insert(std::end(filter_dims), + std::begin(arg.filter_spatial_lengths_), + std::end(arg.filter_spatial_lengths_)); + + int tensor_size = + std::accumulate(filter_dims.begin(), filter_dims.end(), 1, std::multiplies{}); + + const index_t type_convert_grid_size = GridwiseUEltwise::CalculateGridSize(tensor_size); + GridDesc_M0 a_grid_desc_m0_ = + MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256); + GridDesc_M0 b_grid_desc_m0_ = + MakeDescriptor_M0<1>({tensor_size}, {1}, type_convert_grid_size, 256); + + if(!GridwiseUEltwise::CheckValidity(a_grid_desc_m0_, b_grid_desc_m0_)) + { + throw std::runtime_error("wrong! GridwiseUnaryElementwise_1D has invalid setting"); + } + + // run kernel for type conversion + void* p_c_grid_tmp_ = static_cast(arg.p_c_grid_); + InDataType* p_c_grid_tmp_bf16_ = static_cast(p_c_grid_tmp_); + const auto run_type_convert = [&](const auto& kernel) { + float elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(type_convert_grid_size), + dim3(256), + 0, + static_cast(arg.p_workspace_), + p_c_grid_tmp_bf16_, + a_grid_desc_m0_, + b_grid_desc_m0_, + TypeConvertFp32ToBf16Functor{}); + return elapsed_time; + }; + + if constexpr(std::is_same::value) + { + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + has_main_loop>; + + return run_conv(kernel); + } + else + { + const auto kernel_type_convert = + kernel_unary_elementwise_1d; + + const auto kernel_conv = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAddFloatBf16Splitk, + ADataType, // TODO: distiguish A/B datatype + AccDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + has_main_loop>; + + float elapsed_time = 0; + elapsed_time += run_bf16_splitk(kernel_conv); + elapsed_time += run_type_convert(kernel_type_convert); + return elapsed_time; + } + }; + if(has_main_k0_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + } + else + { + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + has_main_loop>; + + return run_conv(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + has_main_loop>; + + return run_conv(kernel); + } + }; + if(has_main_k0_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = + arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float); + } + } + return WorkSpaceSize; + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * + arg.filter_spatial_lengths_[1] * sizeof(float); + } + } + return WorkSpaceSize; + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * + arg.filter_spatial_lengths_[1] * arg.filter_spatial_lengths_[2] * + sizeof(float); + } + } + return WorkSpaceSize; + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override final + { + return GetWorkSpaceSize(*dynamic_cast(p_arg)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp index ff267c6cdf..0517db4415 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -1073,13 +1073,15 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_)) + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], block_2_ctile_map)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - block_2_ctile_map_container_.push_back( - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_)); + block_2_ctile_map_container_.push_back(block_2_ctile_map); } } } @@ -1129,13 +1131,16 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_)) + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); - block_2_ctile_map_container_.push_back( - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_)); + block_2_ctile_map_container_.push_back(block_2_ctile_map); } } } @@ -1194,14 +1199,17 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); c_grid_desc_m_n_container_.push_back(descs[I2]); - if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_)) + auto block_2_ctile_map = + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_); + + if(GridwiseGemm::CheckValidity( + descs[I0], descs[I1], descs[I2], block_2_ctile_map)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2( descs[I2])); - block_2_ctile_map_container_.push_back( - GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_)); + block_2_ctile_map_container_.push_back(block_2_ctile_map); } } } @@ -1241,7 +1249,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) @@ -1286,15 +1294,14 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], arg.c_grid_desc_m_n_container_[i], - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_container_[i])) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); } - const index_t grid_size = - GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.c_grid_desc_m_n_container_[i]); const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); @@ -1316,8 +1323,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho true>; ave_time += launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -1349,8 +1356,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho false>; ave_time += launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -1369,9 +1376,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho return ave_time; } - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -1412,13 +1420,12 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho } // Gridwise GEMM size - for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) { if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], arg.b_grid_desc_k0_n_k1_container_[i], arg.c_grid_desc_m_n_container_[i], - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_container_[i])) { return false; } @@ -1527,10 +1534,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho << ">"; if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ - + str<< " Filter1x1Stride1Pad0"; } - + return str.str(); } diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index ac62448386..c1ab44a28b 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,4 @@ -#ifndef DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP +#pragma once #include #include @@ -8,6 +7,7 @@ #include #include "device.hpp" +#include "device_prop.hpp" #include "device_base.hpp" #include "device_conv_fwd.hpp" #include "convolution_forward_specialization.hpp" @@ -607,6 +607,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, @@ -664,8 +666,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::vector conv_filter_dilations, std::vector input_left_pads, std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) @@ -677,8 +677,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K c_grid_desc_m_n_{}, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, out_element_op_{out_element_op}, @@ -706,14 +704,15 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K b_grid_desc_k0_n_k1_ = descs[I1]; c_grid_desc_m_n_ = descs[I2]; - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } } @@ -726,9 +725,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K CGridDesc_M_N c_grid_desc_m_n_; typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; + Block2CTileMap block_2_ctile_map_; InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; @@ -747,7 +744,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -766,14 +763,14 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -792,11 +789,11 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, true>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -823,11 +820,11 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, - remove_reference_t, + Block2CTileMap, false>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -846,9 +843,10 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return ave_time; } - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -860,18 +858,33 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { - // Input tensors can't be bigger than 2GB each. - constexpr std::size_t GB2 = 2 * 1e9; + if(ck::get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(ck::get_device_name() == "gfx90a") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } - if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() > GB2) - { - return false; - } - if(arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() > GB2) - { - return false; - } - if(arg.c_grid_desc_m_n_.GetElementSpaceSize() > GB2) + // Input tensors can't be bigger than 2GB each. + constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31); + + if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() * sizeof(ADataType) > GB2 || + arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() * sizeof(BDataType) > GB2 || + arg.c_grid_desc_m_n_.GetElementSpaceSize() * sizeof(CDataType) > GB2) { return false; } @@ -921,8 +934,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } bool IsSupportedArgument(const BaseArgument* p_arg) override @@ -960,8 +972,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op}; @@ -1000,8 +1010,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K conv_filter_dilations, input_left_pads, input_right_pads, - 1, - 1, in_element_op, wei_element_op, out_element_op); @@ -1017,8 +1025,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K auto str = std::stringstream(); // clang-format off - str << "DeviceConv" << std::to_string(NumDimSpatial) - << "DFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + str << "DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" << "<" << BlockSize << ", " << MPerBlock << ", " @@ -1035,4 +1042,3 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp new file mode 100644 index 0000000000..b29eb37898 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_bias_add_reduce_xdl_cshuffle.hpp @@ -0,0 +1,813 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_gemm_reduce.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp" +#include "gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemmBiasAddReduce_Xdl_CShuffle + : public DeviceGemmBiasAddReduce +{ + using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeDGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0)); + using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + C0DataType, + C1DataType, + ReduceAccDataType, + DPtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + C1ElementwiseOperation, + DxsReduceOperation, + DxsInElementwiseOperation, + DxsReduceAccElementwiseOperation, + InMemoryDataOperationEnum::Set, + DGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + DGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const C0DataType* p_c0_grid, + const C1DataType* p_c1_grid, + DPtrsGlobal p_ds_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + C1ElementwiseOperation c1_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsReduceAccElementwiseOperation dxs_out_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_c0_grid_{p_c0_grid}, + p_c1_grid_{p_c1_grid}, + p_ds_grid_{p_ds_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)}, + c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)}, + d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c0_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, + d_grid_desc_mblock_mperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + c1_element_op_{c1_element_op}, + dxs_in_element_op_{dxs_in_element_op}, + dxs_out_element_op_{dxs_out_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c1_grid_desc_m_n_); + + d_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const C0DataType* p_c0_grid_; + const C1DataType* p_c1_grid_; + DPtrsGlobal p_ds_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + DGridDesc_M d_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + C1ElementwiseOperation c1_element_op_; + DxsInElementwiseOperation dxs_in_element_op_; + DxsReduceAccElementwiseOperation dxs_out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + C1DataType, + DPtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + C1ElementwiseOperation, + DxsInElementwiseOperation, + DxsReduceAccElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.p_ds_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.c1_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + C1DataType, + DPtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + C1ElementwiseOperation, + DxsInElementwiseOperation, + DxsReduceAccElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.p_ds_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.c1_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const C0DataType* p_c0, + const C1DataType* p_c1, + DPtrsGlobal p_dxs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + C1ElementwiseOperation c1_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsReduceAccElementwiseOperation dxs_out_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_c0, + p_c1, + p_dxs, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + StrideC1, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + dxs_in_element_op, + dxs_out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0, + const void* p_c1, + void* p_dxs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + C1ElementwiseOperation c1_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsReduceAccElementwiseOperation dxs_out_element_op, + index_t /* KBatch */ = 1) override + { + DPtrsGlobal dxs_tuple = *(static_cast(p_dxs)); + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_c0), + static_cast(p_c1), + dxs_tuple, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + StrideC1, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + dxs_in_element_op, + dxs_out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp new file mode 100644 index 0000000000..5ccf1934fe --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp @@ -0,0 +1,586 @@ +#pragma once + +#include +#include + +#include "device.hpp" +#include "device_prop.hpp" +#include "device_base.hpp" +#include "device_gemm.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gemm_specialization.hpp" +#include "element_wise_operation.hpp" +#include "gridwise_gemm_dl_v1r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t K1, + index_t M1PerThread, + index_t N1PerThread, + index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGemmDl + : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + c_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused, but may be useful in future. + index_t M01_; + index_t N01_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmDl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize( + arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1)); + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030") + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmDl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp new file mode 100644 index 0000000000..847000f7b7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -0,0 +1,52 @@ +#pragma once + +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGemmMultipleD : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmMultipleDPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000..2de5897311 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,750 @@ +#pragma once + +#include +#include + +#include "device.hpp" +#include "device_gemm_multiple_d.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "gemm_specialization.hpp" +#include "device_prop.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_xdl_cshuffle(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// input : A[M, K], or A[K, N] +// input : B[K, N], or A[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD +{ + using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + const auto d_grid_desc_m_n = + DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + }); + } + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cv_t; + + return static_cast(nullptr); + }, + Number{}); + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp index 59f4ecc617..d7a10bb6a9 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -9,25 +9,27 @@ namespace device { template + typename DxsInElementwiseOperation, + typename DxsReduceAccElementwiseOperation> struct DeviceGemmReduce : public BaseOperator { - virtual std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - void* p_d0, - void* p_d1, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideB, - ck::index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - D1ElementwiseOperation d1_element_op, - ck::index_t BatchCount = 1) = 0; + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + void* p_dxs, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsReduceAccElementwiseOperation dxs_out_element_op, + ck::index_t BatchCount = 1) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; @@ -35,11 +37,60 @@ struct DeviceGemmReduce : public BaseOperator template + typename DxsInElementwiseOperation, + typename DxsReduceAccElementwiseOperation> using DeviceGemmReducePtr = std::unique_ptr>; + DxsInElementwiseOperation, + DxsReduceAccElementwiseOperation>>; + +template +struct DeviceGemmBiasAddReduce : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0, + const void* p_c1, + void* p_dxs, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + C1ElementwiseOperation c1_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsReduceAccElementwiseOperation dxs_out_element_op, + ck::index_t BatchCount = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmBiasAddReducePtr = + std::unique_ptr>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 1a3fbdf956..989883bd39 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -14,6 +14,9 @@ namespace ck { namespace tensor_operation { namespace device { +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. template + index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce + DxsInElementwiseOperation, + DxsReduceAccElementwiseOperation> { using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; @@ -376,15 +382,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce; + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; // Argument struct Argument : public BaseArgument @@ -430,8 +437,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce; - launch_kernel(kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_d0_grid_, - arg.p_d1_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.d1_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.d_grid_desc_mblock_mperblock_, - arg.block_2_ctile_map_); + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_ds_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); } else { @@ -574,11 +588,12 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce; - launch_kernel(kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_d0_grid_, - arg.p_d1_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.d1_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.d_grid_desc_mblock_mperblock_, - arg.block_2_ctile_map_); + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_ds_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.dxs_in_element_op_, + arg.dxs_out_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); } - return 0; + return elapsed_time; } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -624,8 +642,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce MakeArgumentPointer(const void* p_a, - const void* p_b, - void* p_c, - void* p_d0, - void* p_d1, - index_t MRaw, - index_t NRaw, - index_t KRaw, - index_t StrideA, - index_t StrideB, - index_t StrideC, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - D1ElementwiseOperation d1_element_op, - index_t /* KBatch */ = 1) override + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + void* p_dxs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + DxsInElementwiseOperation dxs_in_element_op, + DxsReduceAccElementwiseOperation dxs_out_element_op, + index_t /* KBatch */ = 1) override { + DPtrsGlobal dxs_tuple = *(static_cast(p_dxs)); return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c), - static_cast(p_d0), - static_cast(p_d1), + dxs_tuple, MRaw, NRaw, KRaw, @@ -701,7 +722,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce #include #include "device.hpp" +#include "device_prop.hpp" #include "device_base.hpp" #include "device_gemm.hpp" #include "common_header.hpp" @@ -257,14 +257,16 @@ struct DeviceGemmXdl b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } } @@ -290,7 +292,7 @@ struct DeviceGemmXdl { using Argument = DeviceGemmXdl::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -310,14 +312,14 @@ struct DeviceGemmXdl if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -339,8 +341,8 @@ struct DeviceGemmXdl remove_reference_t, true>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -370,8 +372,8 @@ struct DeviceGemmXdl remove_reference_t, false>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, + ave_time = launch_and_time_kernel(stream_config, + kernel, dim3(grid_size), dim3(BlockSize), 0, @@ -391,9 +393,10 @@ struct DeviceGemmXdl } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -405,11 +408,31 @@ struct DeviceGemmXdl static bool IsSupportedArgument(const Argument& arg) { + if(ck::get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(ck::get_device_name() == "gfx90a") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } // polymorphic @@ -513,4 +536,3 @@ struct DeviceGemmXdl } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp index 4010965312..1db69dd462 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp @@ -218,8 +218,13 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d c_grid_desc_m_n_ = DeviceGemmXdl_C_Shuffle_Bias_2d::MakeCGridDescriptor_M_N(M, N, StrideC); - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = GridwiseGemm:: @@ -230,9 +235,6 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d GridwiseGemm:: MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } } @@ -264,7 +266,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d { using Argument = DeviceGemmXdl_C_Shuffle_Bias_2d::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) @@ -285,14 +287,14 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -320,8 +322,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d true>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -359,8 +361,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d false>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -382,9 +384,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -399,8 +402,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp index c65ff6022a..b465f8e4ae 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp @@ -227,8 +227,13 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation c_grid_desc_m_n_ = descs[I2]; c0_grid_desc_m_n_ = descs[I3]; - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = GridwiseGemm:: @@ -239,9 +244,6 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation GridwiseGemm:: MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( c0_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } } @@ -273,7 +275,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) @@ -294,14 +296,14 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -329,8 +331,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation true>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -368,8 +370,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation false>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -391,9 +393,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -408,8 +411,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp index 4a478c995d..7a2e1886d3 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp @@ -256,8 +256,13 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add c0_grid_desc_m_n_ = descs[I3]; c1_grid_desc_m_n_ = descs[I4]; - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = GridwiseGemm:: @@ -273,9 +278,6 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add GridwiseGemm:: MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( c1_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } } @@ -312,7 +314,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) @@ -336,14 +338,14 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); @@ -374,8 +376,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add true>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -418,8 +420,8 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add false>; ave_time = launch_and_time_kernel( + stream_config, kernel, - nrepeat, dim3(grid_size), dim3(BlockSize), 0, @@ -443,9 +445,10 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -460,8 +463,7 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp index 440519537e..a74ee81679 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp @@ -9,11 +9,15 @@ #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdl_cshuffle_v1.hpp" #include "tensor_operation/gpu/device/gemm_specialization.hpp" +#include "device_prop.hpp" namespace ck { namespace tensor_operation { namespace device { +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. template + index_t CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceGemm_Xdl_CShuffle : public DeviceGemm { @@ -375,7 +380,8 @@ struct DeviceGemm_Xdl_CShuffle CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock>; + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; // Argument struct Argument : public BaseArgument @@ -399,19 +405,19 @@ struct DeviceGemm_Xdl_CShuffle b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} { - if(GridwiseGemm::CheckValidity( - a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_)) + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) { c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n_); - - block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); } } @@ -435,7 +441,7 @@ struct DeviceGemm_Xdl_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -454,13 +460,16 @@ struct DeviceGemm_Xdl_CShuffle } #endif - if(!GridwiseGemm::CheckValidity( - arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_)) + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); @@ -482,42 +491,22 @@ struct DeviceGemm_Xdl_CShuffle typename GridwiseGemm::DefaultBlock2CTileMap, true>; - if(nrepeat == 0) - { - launch_kernel(kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - } - else - { - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - } + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); } else { @@ -533,52 +522,32 @@ struct DeviceGemm_Xdl_CShuffle typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::DefaultBlock2CTileMap, false>; - - if(nrepeat == 0) - { - launch_kernel(kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - } - else - { - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - } + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); } return ave_time; } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -590,8 +559,15 @@ struct DeviceGemm_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - return GridwiseGemm::CheckValidity( - arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp index db6c884739..d9fc8f7a8a 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp @@ -12,6 +12,7 @@ #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r4.hpp" #include "gemm_specialization.hpp" +#include "device_prop.hpp" #ifndef CK_RUN_KERNEL_AND_TIME #define CK_RUN_KERNEL_AND_TIME 1 @@ -332,17 +333,16 @@ struct DeviceGemmXdlSplitK K, N, StrideB, k_batch_, KPad); c_grid_desc_m_n_ = DeviceGemmXdlSplitK::MakeCGridDescriptor_M_N(M, N, StrideC); + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, b_grid_desc_kbatch_k0_n_k1_, c_grid_desc_m_n_, - M01_, - N01_)) + block_2_ctile_map_)) { c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); } } @@ -385,21 +385,24 @@ struct DeviceGemmXdlSplitK std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, int nrepeat = 1) + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + ShowInfo(arg); + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); @@ -408,50 +411,30 @@ struct DeviceGemmXdlSplitK float ave_time = 0; const auto Run = [&](const auto& kernel) { - if(nrepeat > 0) - { - ShowInfo(arg); - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } + // FIXME: this should be moved outside of DeviceOp + hipGetErrorString( + hipMemset(arg.p_c_grid_, + 0, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() * + sizeof(CDataType))); - if(kbatch > 1 || nrepeat <= 0) - { - hipGetErrorString( - hipMemset(arg.p_c_grid_, - 0, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() * - sizeof(CDataType))); - - launch_kernel(kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); }; + if(has_main_k0_block_loop) { if(kbatch == 1) @@ -531,9 +514,10 @@ struct DeviceGemmXdlSplitK } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -545,11 +529,15 @@ struct DeviceGemmXdlSplitK static bool IsSupportedArgument(const Argument& arg) { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp index 9de5361ab6..ad424d91d9 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp @@ -292,8 +292,7 @@ struct DeviceGemmXdlSplitKCShuffle using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); - using Block2CTileMap = - decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor; // Argument struct Argument : public BaseArgument @@ -338,17 +337,16 @@ struct DeviceGemmXdlSplitKCShuffle K, N, StrideB, k_batch_, KPad); c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, b_grid_desc_kbatch_k0_n_k1_, c_grid_desc_m_n_, - M01_, - N01_)) + block_2_ctile_map_)) { c_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); } } @@ -391,21 +389,24 @@ struct DeviceGemmXdlSplitKCShuffle std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - float Run(const Argument& arg, int nrepeat = 1) + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { + ShowInfo(arg); + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) + arg.block_2_ctile_map_)) { throw std::runtime_error( "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"); } - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); @@ -414,51 +415,29 @@ struct DeviceGemmXdlSplitKCShuffle float ave_time = 0; const auto Run = [&](const auto& kernel) { - if(nrepeat > 0) - { - ShowInfo(arg); - ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } + hipGetErrorString(hipMemset( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType))); - if(kbatch > 1 || nrepeat <= 0) - { - hipGetErrorString(hipMemset( - arg.p_c_grid_, - 0, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * - sizeof(CDataType))); - - launch_kernel(kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); - } + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); }; + if(has_main_k0_block_loop) { if(kbatch == 1) @@ -542,9 +521,10 @@ struct DeviceGemmXdlSplitKCShuffle } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -559,8 +539,7 @@ struct DeviceGemmXdlSplitKCShuffle return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); + arg.block_2_ctile_map_); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp index b9ad39578d..6dfa448fa8 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp @@ -24,57 +24,33 @@ template + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdlops_v2r3( - const StaticallyIndexedArray gemm_descs, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_grouped_gemm_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); -#if 1 - static_for<0, MaxGroupCount, 1>{}([&](auto i) { - if(block_id >= gemm_descs[i].BlockStart_ && block_id < gemm_descs[i].BlockEnd_ && - i < group_count) - { - auto group_id = i; - - GridwiseGemm::template Run( - gemm_descs[group_id].a_ptr, - gemm_descs[group_id].b_ptr, - gemm_descs[group_id].c_ptr, - p_shared, - gemm_descs[group_id].a_grid_desc_k0_m_k1_, - gemm_descs[group_id].b_grid_desc_k0_n_k1_, - gemm_descs[group_id].c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - a_element_op, - b_element_op, - c_element_op, - gemm_descs[group_id].grouped_gemm_block_2_ctile_map_); - } - }); -#else - const auto gemm_desc_ptr = reinterpret_cast(&gemm_descs); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); index_t group_id = 0; - static_for<0, MaxGroupCount, 1>{}([&](auto i) { - group_id = (block_id >= gemm_descs[i].BlockStart && block_id < gemm_descs[i].BlockEnd && - i < group_count) - ? i - : group_id; - }); - - const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart; + for(index_t i = 0; i < group_count; i++) + { + group_id = + (block_id >= gemm_desc_ptr[i].BlockStart_ && block_id < gemm_desc_ptr[i].BlockEnd_) + ? i + : group_id; + } GridwiseGemm::template Run( gemm_desc_ptr[group_id].a_ptr, @@ -87,11 +63,9 @@ __global__ void a_element_op, b_element_op, c_element_op, - gemm_desc_ptr[group_id].block_2_ctile_map_, - block_id_grp); -#endif + gemm_desc_ptr[group_id].grouped_gemm_block_2_ctile_map_); #else - ignore = gemm_descs; + ignore = gemm_descs_const; ignore = group_count; ignore = a_element_op; ignore = b_element_op; @@ -307,6 +281,11 @@ struct DeviceGroupedGemmXdl struct GroupedGemmBlock2CTileMap { + using UnderlyingBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + static_assert( + std::is_same::value, + "Wrong! Should be the same type name"); GroupedGemmBlock2CTileMap() { block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1); @@ -329,7 +308,18 @@ struct DeviceGroupedGemmXdl make_multi_index(idx_top[I0] - BlockStart_)); } - private: + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_2_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_2_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; ck::index_t BlockStart_; }; @@ -372,17 +362,20 @@ struct DeviceGroupedGemmXdl { grid_size_ = 0; - group_count_ = static_cast(gemm_shapes.size()); + p_workspace_ = nullptr; - if(!(group_count_ == p_a.size() && group_count_ == p_b.size() && - group_count_ == p_c.size())) + group_count_ = ck::type_convert(gemm_shapes.size()); + + if(!(group_count_ == ck::type_convert(p_a.size()) && + group_count_ == ck::type_convert(p_b.size()) && + group_count_ == ck::type_convert(p_c.size()))) { throw std::runtime_error("wrong! group_count_ != P_a/b/c.size"); } gemm_desc_kernel_arg_.reserve(group_count_); - for(index_t i = 0; i < gemm_shapes.size(); i++) + for(std::size_t i = 0; i < gemm_shapes.size(); i++) { const index_t M = gemm_shapes[i].M; const index_t N = gemm_shapes[i].N; @@ -399,22 +392,26 @@ struct DeviceGroupedGemmXdl const auto c_grid_desc_m_n_ = DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); - const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_); + const index_t grid_size_grp = + GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0) + .block_2_ctile_map_.CalculateGridSize(c_grid_desc_m_n_); const index_t BlockStart = grid_size_; const index_t BlockEnd = grid_size_ + grid_size_grp; grid_size_ += grid_size_grp; - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + const auto grouped_gemm_block_2_ctile_map_ = + GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + grouped_gemm_block_2_ctile_map_)) { const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - const auto grouped_gemm_block_2_ctile_map_ = - GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, BlockStart); - gemm_desc_kernel_arg_.push_back( GemmDescKernelArg{a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, @@ -448,51 +445,51 @@ struct DeviceGroupedGemmXdl { using Argument = DeviceGroupedGemmXdl::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - StaticallyIndexedArray gemm_desc_kernel_args; - bool has_main_k_block_loop = true; - static_for<0, MaxGroupCount, 1>{}([&](auto i) { - if(i < arg.gemm_desc_kernel_arg_.size()) + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"; + + std::cout << ", arg.b_grid_desc_k0_n_k1_{" + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"; + + std::cout << ", arg.c_grid_desc_m_n_{ " + << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_.GetLength(I1) << "}" + << std::endl; + + if(!GridwiseGemm::CheckValidity( + arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_, + arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_, + arg.gemm_desc_kernel_arg_[i].c_grid_desc_m_n_, + arg.gemm_desc_kernel_arg_[i].grouped_gemm_block_2_ctile_map_)) { - gemm_desc_kernel_args(i) = arg.gemm_desc_kernel_arg_[i]; - - std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" - << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " - << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"; - - std::cout << ", arg.b_grid_desc_k0_n_k1_{" - << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " - << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"; - - std::cout << ", arg.c_grid_desc_m_n_{ " - << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I0) << ", " - << gemm_desc_kernel_args[i].c_grid_desc_m_n_.GetLength(I1) << "}" - << std::endl; - - if(!GridwiseGemm::CheckValidity(gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_, - gemm_desc_kernel_args[i].b_grid_desc_k0_n_k1_, - gemm_desc_kernel_args[i].c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); - } - - const auto K = gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I0) * - gemm_desc_kernel_args[i].a_grid_desc_k0_m_k1_.GetLength(I2); - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) - { - throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); - } + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); } - }); + + const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) * + arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + hipGetErrorString( + hipMemcpy(arg.p_workspace_, + arg.gemm_desc_kernel_arg_.data(), + arg.gemm_desc_kernel_arg_.size() * sizeof(GemmDescKernelArg), + hipMemcpyHostToDevice)); float ave_time = 0; @@ -502,23 +499,23 @@ struct DeviceGroupedGemmXdl kernel_grouped_gemm_xdlops_v2r3, + GemmDescKernelArg, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - true, - MaxGroupCount>; + true>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - gemm_desc_kernel_args, - arg.gemm_desc_kernel_arg_.size(), - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); } else { @@ -526,32 +523,33 @@ struct DeviceGroupedGemmXdl kernel_grouped_gemm_xdlops_v2r3, + GemmDescKernelArg, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - false, - MaxGroupCount>; + false>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - gemm_desc_kernel_args, - arg.gemm_desc_kernel_arg_.size(), - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); } return ave_time; } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; @@ -563,7 +561,7 @@ struct DeviceGroupedGemmXdl static bool IsSupportedArgument(const Argument& arg) { - if(arg.gemm_desc_kernel_arg_.size() != arg.group_count_) + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) return false; else return true; @@ -630,6 +628,11 @@ struct DeviceGroupedGemmXdl return str.str(); } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(GemmDescKernelArg); + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp index 651d31ae2f..41fb11b7de 100644 --- a/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp +++ b/include/ck/tensor_operation/gpu/device/device_pool2d_fwd_nhwc_nhwc.hpp @@ -17,7 +17,7 @@ template ::opType; + using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; - - static constexpr bool BetaIsZero = true; + typename reduce_unary_operator::AccElementwiseOperation; static constexpr index_t InSrcOutDstVectorDim = 0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is @@ -180,13 +177,10 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd invariant_lowest_length_ = C; reduce_lowest_length_ = window_spatial_lengths[1]; - // TODO: is this correct? - if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG) - { - ck::index_t divider = window_spatial_lengths[0] * window_spatial_lengths[1]; - in_element_op_ = InElementwiseOperation{divider}; - acc_element_op_ = AccElementwiseOperation{divider}; - } + int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1]; + + std::tie(in_element_op_, acc_element_op_) = + reduce_unary_operator::GetElementwiseOperator(reduceLength); } const InDataType* p_in_dev_; @@ -204,30 +198,30 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C : public DevicePool2dFwd struct Invoker : public BaseInvoker { - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise; + using gridwise_reduce = + GridwiseReduction_mk_to_m_threadwise; const auto kernel = kernel_reduce_threadwise(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; diff --git a/include/ck/tensor_operation/gpu/device/device_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_reduce.hpp index 50fa64dab8..6f367a8747 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce.hpp @@ -16,35 +16,18 @@ namespace device { template struct DeviceReduce : public BaseOperator { - virtual long_index_t GetWorkspaceSizeInBytes(const std::vector inLengths, - const std::vector reduceDims) - { - (void)inLengths; - (void)reduceDims; - - return (0); - }; - - virtual bool HasFurtherCall() { return (false); }; - - virtual std::vector GetWorkspace2dLengths(const BaseArgument* argPtr) - { - (void)argPtr; - return (std::vector{0, 0}); - }; - virtual std::unique_ptr - MakeArgumentPointer(const std::vector inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, const std::vector reduceDims, float alpha, float beta, const void* in_dev, + const void* in_index_dev, void* out_dev, - void* out_indices_dev, - void* workspace_dev, + void* out_index_dev, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op) = 0; diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_blockwise.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_blockwise.hpp deleted file mode 100644 index 4f17989b53..0000000000 --- a/include/ck/tensor_operation/gpu/device/device_reduce_blockwise.hpp +++ /dev/null @@ -1,373 +0,0 @@ -#ifndef DEVICE_REDUCE_BLOCKWISE_HPP -#define DEVICE_REDUCE_BLOCKWISE_HPP - -#include -#include -#include "device.hpp" -#include "device_reduce.hpp" -#include "device_reduce_common.hpp" -#include "gridwise_2d_reduction_blockwise.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DeviceReduceBlockWise : public DeviceReduce -{ - static_assert(Rank <= 6, "Bigger Rank size is not supported!"); - static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, - "Invalid thread cluster size assignments!"); - - static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || - (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && - (MThreadSliceSize % OutDstVectorSize == 0), - "Invalid thread slice sizes and/or vector sizes configuration, please check!"); - - using IndexDataType = int32_t; - - static constexpr bool BetaIsZero = NeedIndices; - - static constexpr index_t NumInvariantDim = Rank - NumReduceDim; - - static constexpr index_t numSrcDim = Rank; - static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; - static constexpr bool reduceAllDim = (NumInvariantDim == 0); - - static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - - static auto MakeSrc2dDescriptor(const std::vector& inLengths, - const std::vector& inStrides) - { - const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); - - const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - - const auto in_grid_desc_m_k = [&]() { - if constexpr(reduceAllDim) - { - const auto one_dim_inDesc = transform_tensor_descriptor( - inDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), - make_tuple(Sequence<0>{})); - - return transform_tensor_descriptor(one_dim_inDesc, - make_tuple(make_unmerge_transform(make_tuple( - 1, one_dim_inDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - } - else - { - using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; - using ReduceDims = typename arithmetic_sequence_gen::type; - - const auto reduceDimLengths = - make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); - - return transform_tensor_descriptor( - inDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(reduceDimLengths)), - make_tuple(InvariantDims{}, ReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - }(); - - const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); - const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - - const auto inPad_M = - math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - const auto inPad_K = - math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; - - auto in_grid_desc_m_k_padded = transform_tensor_descriptor( - in_grid_desc_m_k, - make_tuple(make_right_pad_transform(invariantLength, inPad_M), - make_right_pad_transform(reduceLength, inPad_K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return (in_grid_desc_m_k_padded); - }; - - static auto MakeDst1dDescriptor(const std::vector& outLengths, - const std::vector& outStrides) - { - const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); - - auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - auto out_grid_desc_m = transform_tensor_descriptor( - outDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); - - const auto inPad = - math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - - auto out_grid_desc_m_padded = transform_tensor_descriptor( - out_grid_desc_m, - make_tuple(make_right_pad_transform(invariantLength, inPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - return (out_grid_desc_m_padded); - }; - - struct Argument : public BaseArgument - { - Argument(const std::vector inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, - const std::vector reduceDims, - float alpha, - float beta, - const InDataType* in_dev, - OutDataType* out_dev, - IndexDataType* out_indices_dev, - AccDataType* workspace_dev, - const InElementwiseOperation in_elementwise_op, - const AccElementwiseOperation acc_elementwise_op) - : outLengths_{outLengths}, - outStrides_{outStrides}, - in_dev_{in_dev}, - out_dev_{out_dev}, - out_indices_dev_{out_indices_dev}, - in_elementwise_op_{in_elementwise_op}, - acc_elementwise_op_{acc_elementwise_op} - { - (void)workspace_dev; - - inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); - inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); - - alpha_ = type_convert(alpha); - beta_ = type_convert(beta); - - std::tie(invariant_total_length, reduce_total_length) = - get_2d_lengths(inLengths_); - - if constexpr(NumInvariantDim == 0) - invariant_lowest_length = 1; - else - invariant_lowest_length = inLengths_[NumInvariantDim - 1]; - - reduce_lowest_length = inLengths_[Rank - 1]; - - gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / - M_BlockTileSize; - } - - std::vector inLengths_; - std::vector inStrides_; - std::vector outLengths_; - std::vector outStrides_; - - AccDataType alpha_; - AccDataType beta_; - - const InDataType* in_dev_; - OutDataType* out_dev_; - IndexDataType* out_indices_dev_; - - InElementwiseOperation in_elementwise_op_; - AccElementwiseOperation acc_elementwise_op_; - - int invariant_lowest_length; - int reduce_lowest_length; - size_t invariant_total_length; - size_t reduce_total_length; - - size_t gridSize; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, int nrepeat = 1) - { - const auto in_grid_desc_m_k = - DeviceReduceBlockWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); - const auto out_grid_desc_m = - DeviceReduceBlockWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); - using InGridDesc_M_K = decltype(in_grid_desc_m_k); - using OutGridDesc_M = decltype(out_grid_desc_m); - - using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise; - - float avg_time = 0; - - const auto kernel = kernel_reduce_blockwise; - - avg_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(arg.gridSize), - dim3(BlockSize), - 0, - in_grid_desc_m_k, - out_grid_desc_m, - arg.in_elementwise_op_, - arg.acc_elementwise_op_, - arg.alpha_, - arg.in_dev_, - arg.beta_, - arg.out_dev_, - nullptr, - arg.out_indices_dev_); - - return (avg_time); - }; - - float Run(const BaseArgument* p_arg, int nrepeat = 1) override - { - return Run(*dynamic_cast(p_arg), nrepeat); - }; - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* pArg = dynamic_cast(p_arg); - - if constexpr(InSrcVectorDim == 0) - { - if constexpr(NumInvariantDim == 0) - { - return (false); - } - else - { - if(pArg->inStrides_[NumInvariantDim - 1] != 1) - return (false); - - if(pArg->invariant_lowest_length % InSrcVectorSize != 0) - return (false); - }; - } - else - { - if(pArg->inStrides_[Rank - 1] != 1) - return (false); - - if(pArg->reduce_lowest_length % InSrcVectorSize != 0) - return (false); - }; - - // To improve - if(pArg->invariant_lowest_length % OutDstVectorSize != 0) - return (false); - - // cases with very small reduce_total_length should be handled by the ThreadWise method - if(pArg->reduce_total_length / KThreadSliceSize < 2) - return (false); - - return (true); - }; - - std::unique_ptr - MakeArgumentPointer(const std::vector inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, - const std::vector reduceDims, - float alpha, - float beta, - const void* in_dev, - void* out_dev, - void* out_indices_dev, - void* workspace_dev, - const InElementwiseOperation in_elementwise_op, - const AccElementwiseOperation acc_elementwise_op) override - { - return std::make_unique(inLengths, - inStrides, - outLengths, - outStrides, - reduceDims, - alpha, - beta, - static_cast(in_dev), - static_cast(out_dev), - static_cast(out_indices_dev), - static_cast(workspace_dev), - in_elementwise_op, - acc_elementwise_op); - }; - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(); - }; - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceReduceBlockWise<" << BlockSize << ","; - str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; - str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; - str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_blockwise_second_call.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_blockwise_second_call.hpp deleted file mode 100644 index d3b1b4b5c3..0000000000 --- a/include/ck/tensor_operation/gpu/device/device_reduce_blockwise_second_call.hpp +++ /dev/null @@ -1,327 +0,0 @@ -#ifndef DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP -#define DEVICE_REDUCE_BLOCKWISE_SECOND_CALL_HPP - -#include -#include -#include "device.hpp" -#include "device_reduce.hpp" -#include "device_reduce_common.hpp" -#include "gridwise_2d_reduction_blockwise.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DeviceReduceBlockWiseSecondCall - : public DeviceReduce -{ - static_assert(Rank <= 6, "Bigger Rank size is not supported!"); - static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, - "Invalid thread cluster size assignments!"); - - static_assert((InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0) && - (MThreadSliceSize % OutDstVectorSize == 0), - "Invalid thread slice sizes and/or vector sizes configuration, please check!"); - - using IndexDataType = int32_t; - - static constexpr bool BetaIsZero = NeedIndices; - - static_assert( - std::is_same::value, - "InDataType and AccDataType should be the same to use DEviceReduceBlockWiseSecondCall!"); - - static constexpr index_t NumInvariantDim = Rank - NumReduceDim; - - static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; - - static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - - static auto MakeSrc2dDescriptor(const std::vector& inLengths, - const std::vector& inStrides) - { - const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<2>{}); - const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<2>{}); - - const auto in_grid_desc_m_k = - make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - - const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); - const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - - const auto inPad_M = - math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - const auto inPad_K = - math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; - - auto in_grid_desc_m_k_padded = transform_tensor_descriptor( - in_grid_desc_m_k, - make_tuple(make_right_pad_transform(invariantLength, inPad_M), - make_right_pad_transform(reduceLength, inPad_K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return (in_grid_desc_m_k_padded); - }; - - static auto MakeDst1dDescriptor(const std::vector& outLengths, - const std::vector& outStrides) - { - const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); - - auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); - - auto out_grid_desc_m = transform_tensor_descriptor( - outDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); - - const auto outPad = - math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - - auto out_grid_desc_m_padded = transform_tensor_descriptor( - out_grid_desc_m, - make_tuple(make_right_pad_transform(invariantLength, outPad)), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0>{})); - return (out_grid_desc_m_padded); - }; - - struct Argument : public BaseArgument - { - Argument(const std::vector& inLengths, - const std::vector& inStrides, - const std::vector& outLengths, - const std::vector& outStrides, - float alpha, - float beta, - const InDataType* in_dev, - OutDataType* out_dev, - IndexDataType* out_indices_dev, - AccDataType* workspace_dev, - const InElementwiseOperation& in_elementwise_op, - const AccElementwiseOperation& acc_elementwise_op) - : inLengths_(inLengths), - inStrides_(inStrides), - outLengths_(outLengths), - outStrides_(outStrides), - in_dev_{in_dev}, - out_dev_{out_dev}, - out_indices_dev_{out_indices_dev}, - in_elementwise_op_(in_elementwise_op), - acc_elementwise_op_(acc_elementwise_op) - { - alpha_ = type_convert(alpha); - beta_ = type_convert(beta); - - invariant_total_length = inLengths[0]; - reduce_total_length = inLengths[1]; - - invariant_lowest_length = inLengths[0]; - reduce_lowest_length = inLengths[1]; - - gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / - M_BlockTileSize; - - size_t ws_buf2_bytes_offset = math::integer_least_multiple( - invariant_total_length * reduce_total_length * sizeof(AccDataType), 64); - - if constexpr(NeedIndices) - workspace_indices_dev_ = reinterpret_cast( - reinterpret_cast(workspace_dev) + ws_buf2_bytes_offset); - else - workspace_indices_dev_ = nullptr; - } - - std::vector inLengths_; - std::vector inStrides_; - std::vector outLengths_; - std::vector outStrides_; - - AccDataType alpha_; - AccDataType beta_; - - const InDataType* in_dev_; - OutDataType* out_dev_; - IndexDataType* out_indices_dev_; - IndexDataType* workspace_indices_dev_; - - InElementwiseOperation in_elementwise_op_; - AccElementwiseOperation acc_elementwise_op_; - - int invariant_lowest_length; - int reduce_lowest_length; - size_t invariant_total_length; - size_t reduce_total_length; - - size_t gridSize; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, int nrepeat = 1) - { - const auto in_grid_desc_m_k = DeviceReduceBlockWiseSecondCall::MakeSrc2dDescriptor( - arg.inLengths_, arg.inStrides_); - const auto out_grid_desc_m = DeviceReduceBlockWiseSecondCall::MakeDst1dDescriptor( - arg.outLengths_, arg.outStrides_); - using InGridDesc_M_K = decltype(in_grid_desc_m_k); - using OutGridDesc_M = decltype(out_grid_desc_m); - - using GridwiseReduce = GridwiseReduction_mk_to_m_blockwise; - - float avg_time = 0; - - const auto kernel = kernel_reduce_blockwise_second_call; - - avg_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(arg.gridSize), - dim3(BlockSize), - 0, - in_grid_desc_m_k, - out_grid_desc_m, - arg.in_elementwise_op_, - arg.acc_elementwise_op_, - arg.alpha_, - arg.in_dev_, - arg.beta_, - arg.out_dev_, - arg.workspace_indices_dev_, - arg.out_indices_dev_); - - return (avg_time); - }; - - float Run(const BaseArgument* p_arg, int nrepeat = 1) override - { - return Run(*dynamic_cast(p_arg), nrepeat); - }; - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* pArg = dynamic_cast(p_arg); - - if constexpr(InSrcVectorDim == 0) - return (false); - - if(pArg->reduce_lowest_length % InSrcVectorSize != 0) - return (false); - - // To improve - if(pArg->invariant_lowest_length % OutDstVectorSize != 0) - return (false); - - // cases with very small reduce_total_length should be handled by the ThreadWise method - if(pArg->reduce_total_length / KThreadSliceSize < 2) - return (false); - - return (true); - }; - - std::unique_ptr - MakeArgumentPointer(const std::vector inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, - const std::vector reduceDims, - float alpha, - float beta, - const void* in_dev, - void* out_dev, - void* out_indices_dev, - void* workspace_dev, - const InElementwiseOperation in_elementwise_op, - const AccElementwiseOperation acc_elementwise_op) override - { - (void)reduceDims; - - return std::make_unique(inLengths, - inStrides, - outLengths, - outStrides, - alpha, - beta, - static_cast(in_dev), - static_cast(out_dev), - static_cast(out_indices_dev), - static_cast(workspace_dev), - in_elementwise_op, - acc_elementwise_op); - }; - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(); - }; - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceReduceBlockWiseSecondCall<" << BlockSize << ","; - str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; - str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; - str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp index 038c754722..f68a392821 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce_common.hpp @@ -14,13 +14,13 @@ namespace device { // here, inLengths[] is already shuffled so that lengths of invariant dims are included before those // of reduce dims -template -std::pair get_2d_lengths(const std::vector& inLengths) +template +std::pair get_2d_lengths(const std::vector& inLengths) { static_assert(Rank <= 6, "bigger Rank size not supported!"); - size_t invariant_total_length = 1; - size_t reduce_total_length = 1; + long_index_t invariant_total_length = 1; + long_index_t reduce_total_length = 1; constexpr int NumInvariantDim = Rank - NumReduceDim; @@ -35,13 +35,13 @@ std::pair get_2d_lengths(const std::vector& inLengths) // helper functions using variadic template arguments template -auto make_tuple_from_array_and_index_seq(const std::vector& lengths, Sequence) +auto make_tuple_from_array_and_index_seq(const std::vector& lengths, Sequence) { return make_tuple(static_cast(lengths[Ns])...); }; template -static auto make_tuple_from_array(const std::vector& lengths, Number) +auto make_tuple_from_array(const std::vector& lengths, Number) { static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); @@ -51,10 +51,10 @@ static auto make_tuple_from_array(const std::vector& lengths, Number -std::vector shuffle_tensor_dimensions(const std::vector& origLengthsStrides, - const std::vector& reduceDims) +std::vector shuffle_tensor_dimensions(const std::vector& origLengthsStrides, + const std::vector& reduceDims) { - std::vector newLengthsStrides; + std::vector newLengthsStrides; assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size()); diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_multiblock_atomic_add.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp similarity index 52% rename from include/ck/tensor_operation/gpu/device/device_reduce_multiblock_atomic_add.hpp rename to include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp index 889c366875..99e79e3a1a 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce_multiblock_atomic_add.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp @@ -1,5 +1,5 @@ -#ifndef DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP -#define DEVICE_REDUCE_MULTIBLOCK_ATOMIC_ADD_HPP +#ifndef DEVICE_REDUCE_MULTIBLOCK_HPP +#define DEVICE_REDUCE_MULTIBLOCK_HPP #include #include @@ -7,8 +7,9 @@ #include "device_base.hpp" #include "device_reduce.hpp" #include "device_reduce_common.hpp" -#include "gridwise_2d_reduction_multiblock_atomic_add.hpp" +#include "gridwise_2d_reduction_multiblock.hpp" #include "gridwise_set_buffer_value.hpp" +#include "reduction_operator.hpp" namespace ck { namespace tensor_operation { @@ -22,8 +23,10 @@ template -struct DeviceReduceMultiBlockAtomicAdd - : public DeviceReduce +struct DeviceReduceMultiBlock : public DeviceReduce { static_assert(Rank <= 6, "Bigger Rank size is not supported!"); static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, @@ -46,26 +48,37 @@ struct DeviceReduceMultiBlockAtomicAdd using IndexDataType = int32_t; + static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex; + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t numSrcDim = Rank; static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; static constexpr bool reduceAllDim = (NumInvariantDim == 0); - static constexpr bool support_AtomicAdd = - std::is_same::value || std::is_same::value; + // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added + // later + static constexpr bool use_multiblock = + (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd); - static_assert(!NeedIndices && support_AtomicAdd, - "MultiBlockAtomicAdd method can only be used with non-indiced operation and when " - "having float/double output type!"); + static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType::value, + "The OutDataType must support the specified OutMemoryDataOperation!"); - static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static_assert(!use_multiblock || (use_multiblock && !OutputIndex), + "MultiBlock reduction can only be used when outputing index is not required"); - static auto MakeSrc2dDescriptor(const std::vector& inLengths, - const std::vector& inStrides, + static_assert( + ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation), + "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!"); + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, int blkGroupSize, - int kBlockTileIterations) + int numBlockTileIteration) { const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); @@ -109,7 +122,7 @@ struct DeviceReduceMultiBlockAtomicAdd const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations; + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; const auto inPad_M = math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; @@ -124,8 +137,8 @@ struct DeviceReduceMultiBlockAtomicAdd return (in_grid_desc_m_k_padded); }; - static auto MakeDst1dDescriptor(const std::vector& outLengths, - const std::vector& outStrides) + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) { const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); @@ -151,31 +164,56 @@ struct DeviceReduceMultiBlockAtomicAdd return (out_grid_desc_m_padded); }; + static auto MakeDst1dDescriptorForBufferSet(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); + const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto length = out_grid_desc_m.GetLength(Number<0>{}); + + const auto pad = math::integer_least_multiple(length, BlockSize) - length; + + auto out_grid_desc_m_padded = + transform_tensor_descriptor(out_grid_desc_m, + make_tuple(make_right_pad_transform(length, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + struct Argument : public BaseArgument { - Argument(const std::vector inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, const std::vector reduceDims, float alpha, float beta, const InDataType* in_dev, + const IndexDataType* in_index_dev, OutDataType* out_dev, - IndexDataType* out_indices_dev, - AccDataType* workspace_dev, + IndexDataType* out_index_dev, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op) : outLengths_{outLengths}, outStrides_{outStrides}, in_dev_{in_dev}, + in_index_dev_{in_index_dev}, out_dev_{out_dev}, + out_index_dev_{out_index_dev}, in_elementwise_op_{in_elementwise_op}, acc_elementwise_op_{acc_elementwise_op} { - (void)out_indices_dev; - (void)workspace_dev; - inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); @@ -192,24 +230,35 @@ struct DeviceReduceMultiBlockAtomicAdd reduce_lowest_length = inLengths_[Rank - 1]; - int iterations = 1; - while(true) + if constexpr(use_multiblock) { - int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / - (K_BlockTileSize * iterations); - // we want the blkGroupSize be not more than 128 - if(testBlkGroupSize <= 128) - break; + int iterations = 1; + while(true) + { + int testBlkGroupSize = + (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); - iterations++; + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration = iterations; + } + else + { + blkGroupSize = 1; + numBlockTileIteration = + (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; }; - blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / - (K_BlockTileSize * iterations); - - kBlockTileIterations = iterations; - gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / M_BlockTileSize * blkGroupSize; @@ -217,27 +266,29 @@ struct DeviceReduceMultiBlockAtomicAdd math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; } - std::vector inLengths_; - std::vector inStrides_; - std::vector outLengths_; - std::vector outStrides_; + std::vector inLengths_; + std::vector inStrides_; + std::vector outLengths_; + std::vector outStrides_; AccDataType alpha_; AccDataType beta_; const InDataType* in_dev_; + const IndexDataType* in_index_dev_; OutDataType* out_dev_; + IndexDataType* out_index_dev_; InElementwiseOperation in_elementwise_op_; AccElementwiseOperation acc_elementwise_op_; - int invariant_lowest_length; - int reduce_lowest_length; - size_t invariant_total_length; - size_t reduce_total_length; + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; - index_t blkGroupSize; - index_t kBlockTileIterations; + int blkGroupSize; + int numBlockTileIteration; size_t gridSize; size_t gridSize_pre; @@ -245,97 +296,107 @@ struct DeviceReduceMultiBlockAtomicAdd struct Invoker : public BaseInvoker { - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - const auto in_grid_desc_m_k = DeviceReduceMultiBlockAtomicAdd::MakeSrc2dDescriptor( - arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations); - const auto out_grid_desc_m = DeviceReduceMultiBlockAtomicAdd::MakeDst1dDescriptor( + const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m = + DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet( arg.outLengths_, arg.outStrides_); - using InGridDesc_M_K = decltype(in_grid_desc_m_k); - using OutGridDesc_M = decltype(out_grid_desc_m); - using GridwiseReduce = - GridwiseReduction_mk_to_m_multiblock_atomic_add; + using InGridDesc_M_K = decltype(in_grid_desc_m_k); + using OutGridDesc_M = decltype(out_grid_desc_m); + using OutGridDesc_M_2 = decltype(out_grid_desc_m_2); + + using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock; + + const auto kernel_main = kernel_reduce_multiblock; float avg_time = 0; - KernelTimer timer; - - const auto kernel_pre = kernel_buffer_set_value; - const auto kernel_main = kernel_reduce_multiblock_atocmi_add; - - printf("launch_and_time_kernel: grid_dim {%ld, 1, 1}, block_dim {%d, 1, 1} \n", - arg.gridSize, - BlockSize); - printf("Warm up\n"); - - for(int i = 0; i < nrepeat + 1; i++) + if constexpr(use_multiblock) { - if(i == 1) - timer.Start(); + const auto identityVal = + ck::reduce::GetIdentityValueForInMemoryDataOperation( + OutMemoryDataOperation); - launch_kernel(kernel_pre, - dim3(arg.gridSize_pre), - dim3(BlockSize), - 0, - out_grid_desc_m, - arg.out_dev_, - static_cast(0.0f)); + const auto kernel_pre = + kernel_buffer_set_value; - launch_kernel(kernel_main, - dim3(arg.gridSize), - dim3(BlockSize), - 0, - in_grid_desc_m_k, - out_grid_desc_m, - arg.in_elementwise_op_, - arg.acc_elementwise_op_, - arg.blkGroupSize, - arg.kBlockTileIterations, - arg.alpha_, - arg.in_dev_, - arg.out_dev_); + avg_time += launch_and_time_kernel(stream_config, + kernel_pre, + dim3(arg.gridSize_pre), + dim3(BlockSize), + 0, + out_grid_desc_m_2, + arg.out_dev_, + identityVal); }; - timer.End(); - - avg_time = timer.GetElapsedTime() / nrepeat; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.acc_elementwise_op_, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.in_index_dev_, + arg.beta_, + arg.out_dev_, + arg.out_index_dev_); return (avg_time); }; - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); }; }; - bool IsSupportedArgument(const BaseArgument* p_arg) override + static bool IsSupportedArgument(const Argument* pArg) { - const Argument* pArg = dynamic_cast(p_arg); + if constexpr(use_multiblock) + { + if(static_cast(pArg->beta_) != 0.0f) + return (false); + }; if constexpr(InSrcVectorDim == 0) { @@ -361,36 +422,48 @@ struct DeviceReduceMultiBlockAtomicAdd return (false); }; - if(static_cast(pArg->beta_) != 0.0f) - return (false); - // To improve if(pArg->invariant_lowest_length % OutDstVectorSize != 0) return (false); - // cases with small reduce_total_length should be handled by the BlockWise method - if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize) - return (false); + if constexpr(use_multiblock) + { + // blkGroupSize of 1 should be handled by Blockwise path using + // InMemoryDataOperationEnum::Set + if(pArg->blkGroupSize == 1) + return (false); - // This is very strong restriction, but needed to avoid some failure - if(pArg->invariant_lowest_length % M_BlockTileSize != 0) - return (false); + // This is very strong restriction, but needed to avoid some failure + if(pArg->invariant_lowest_length % M_BlockTileSize != 0) + return (false); + } + else + { + // cases with very small reduce_total_length should be handled by ThreadWise kernel + // if(pArg->reduce_total_length / KThreadSliceSize < 2) + // return (false); + }; return (true); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(dynamic_cast(p_arg)); }; std::unique_ptr - MakeArgumentPointer(const std::vector inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, const std::vector reduceDims, float alpha, float beta, const void* in_dev, + const void* in_index_dev, void* out_dev, - void* out_indices_dev, - void* workspace_dev, + void* out_index_dev, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op) override { @@ -402,9 +475,9 @@ struct DeviceReduceMultiBlockAtomicAdd alpha, beta, static_cast(in_dev), + static_cast(in_index_dev), static_cast(out_dev), - static_cast(out_indices_dev), - static_cast(workspace_dev), + static_cast(out_index_dev), in_elementwise_op, acc_elementwise_op); }; @@ -419,7 +492,7 @@ struct DeviceReduceMultiBlockAtomicAdd auto str = std::stringstream(); // clang-format off - str << "DeviceReduceMultiBlockAtomicAdd<" << BlockSize << ","; + str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceReduceBlockWise<" : "DeviceReduceMultiBlock<") << BlockSize << ","; str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_multiblock_partial_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_multiblock_partial_reduce.hpp deleted file mode 100644 index d583f7f1b8..0000000000 --- a/include/ck/tensor_operation/gpu/device/device_reduce_multiblock_partial_reduce.hpp +++ /dev/null @@ -1,439 +0,0 @@ -#ifndef DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP -#define DEVICE_REDUCE_MULTIBLOCK_PARTIAL_REDUCE_HPP - -#include -#include -#include "device.hpp" -#include "device_reduce.hpp" -#include "device_reduce_common.hpp" -#include "gridwise_2d_reduction_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template -struct DeviceReduceMultiBlockPartialReduce - : public DeviceReduce -{ - static_assert(Rank <= 6, "Bigger Rank size is not supported!"); - static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, - "Invalid thread cluster size assignments!"); - - static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || - (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0), - "Invalid thread slice sizes and/or vector sizes configuration, please check!"); - - static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!"); - - using IndexDataType = int32_t; - - static constexpr index_t NumInvariantDim = Rank - NumReduceDim; - - static constexpr index_t numSrcDim = Rank; - static constexpr index_t numDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; - static constexpr bool reduceAllDim = (NumInvariantDim == 0); - - static constexpr int M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - - static constexpr int MaxBlockGroupSize = 256; - - long_index_t GetWorkspaceSizeInBytes(const std::vector inLengths, - const std::vector reduceDims) override - { - size_t invariant_total_length; - size_t reduce_total_length; - - auto inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); - - std::tie(invariant_total_length, reduce_total_length) = - get_2d_lengths(inLengths_); - - int iterations = 1; - while(true) - { - int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / - (K_BlockTileSize * iterations); - - if(testBlkGroupSize <= MaxBlockGroupSize) - break; - - iterations++; - }; - - int blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / - (K_BlockTileSize * iterations); - - long_index_t workspace_size = invariant_total_length * blkGroupSize; - - long_index_t wsSizeInBytes = - !NeedIndices - ? workspace_size * sizeof(AccDataType) - : workspace_size * (sizeof(AccDataType) + sizeof(int32_t)) + 64 + sizeof(int); - - return (wsSizeInBytes); - }; - - bool HasFurtherCall() override { return (true); }; - - static auto MakeSrc2dDescriptor(const std::vector& inLengths, - const std::vector& inStrides, - int blkGroupSize, - int kBlockTileIterations) - { - const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); - const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); - - const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - - const auto in_grid_desc_m_k = [&]() { - if constexpr(reduceAllDim) - { - const auto one_dim_inDesc = transform_tensor_descriptor( - inDesc, - make_tuple(make_merge_transform(tupleSrcLengths)), - make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), - make_tuple(Sequence<0>{})); - - return transform_tensor_descriptor(one_dim_inDesc, - make_tuple(make_unmerge_transform(make_tuple( - 1, one_dim_inDesc.GetLength(Number<0>{})))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1>{})); - } - else - { - using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; - using ReduceDims = typename arithmetic_sequence_gen::type; - - const auto reduceDimLengths = - make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); - const auto invariantDimLengths = - make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); - - return transform_tensor_descriptor( - inDesc, - make_tuple(make_merge_transform(invariantDimLengths), - make_merge_transform(reduceDimLengths)), - make_tuple(InvariantDims{}, ReduceDims{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - }(); - - const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); - const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - - const int reduceSizePerBlock = K_BlockTileSize * kBlockTileIterations; - const auto inPad_M = - math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; - - auto in_grid_desc_m_k_padded = transform_tensor_descriptor( - in_grid_desc_m_k, - make_tuple(make_right_pad_transform(invariantLength, inPad_M), - make_right_pad_transform(reduceLength, inPad_K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return (in_grid_desc_m_k_padded); - }; - - static auto MakeWorkspace2dDescriptor(int invariantLength, int blkGroupSize) - { - auto ws_desc_m_k = - make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize)); - - const auto wsPad = - math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; - - auto ws_desc_m_k_padded = - transform_tensor_descriptor(ws_desc_m_k, - make_tuple(make_right_pad_transform(invariantLength, wsPad), - make_pass_through_transform(blkGroupSize)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return (ws_desc_m_k_padded); - }; - - struct Argument : public BaseArgument - { - Argument(const std::vector inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, - const std::vector reduceDims, - float alpha, - float beta, - const InDataType* in_dev, - OutDataType* out_dev, - IndexDataType* out_indices_dev, - AccDataType* workspace_dev, - const InElementwiseOperation in_elementwise_op, - const AccElementwiseOperation acc_elementwise_op) - : outLengths_{outLengths}, - outStrides_{outStrides}, - in_dev_{in_dev}, - out_dev_{out_dev}, - out_indices_dev_{out_indices_dev}, - workspace_dev_{workspace_dev}, - in_elementwise_op_{in_elementwise_op}, - acc_elementwise_op_{acc_elementwise_op} - { - inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); - inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); - - alpha_ = type_convert(alpha); - beta_ = type_convert(beta); - - std::tie(invariant_total_length, reduce_total_length) = - get_2d_lengths(inLengths_); - - if constexpr(NumInvariantDim == 0) - invariant_lowest_length = 1; - else - invariant_lowest_length = inLengths_[NumInvariantDim - 1]; - - reduce_lowest_length = inLengths_[Rank - 1]; - - int iterations = 1; - while(true) - { - int testBlkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / - (K_BlockTileSize * iterations); - - if(testBlkGroupSize <= MaxBlockGroupSize) - break; - - iterations++; - }; - - blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / - (K_BlockTileSize * iterations); - - kBlockTileIterations = iterations; - - gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / - M_BlockTileSize * blkGroupSize; - - size_t ws_buf2_bytes_offset = math::integer_least_multiple( - invariant_total_length * blkGroupSize * sizeof(AccDataType), 64); - - if constexpr(NeedIndices) - workspace_indices_dev_ = reinterpret_cast( - reinterpret_cast(workspace_dev_) + ws_buf2_bytes_offset); - else - workspace_indices_dev_ = nullptr; - } - - std::vector inLengths_; - std::vector inStrides_; - std::vector outLengths_; - std::vector outStrides_; - - AccDataType alpha_; - AccDataType beta_; - - const InDataType* in_dev_; - OutDataType* out_dev_; - IndexDataType* out_indices_dev_; - AccDataType* workspace_dev_; - IndexDataType* workspace_indices_dev_; - - InElementwiseOperation in_elementwise_op_; - AccElementwiseOperation acc_elementwise_op_; - - int invariant_lowest_length; - int reduce_lowest_length; - size_t invariant_total_length; - size_t reduce_total_length; - - index_t blkGroupSize; - index_t kBlockTileIterations; - size_t gridSize; - }; - - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, int nrepeat = 1) - { - const auto in_grid_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeSrc2dDescriptor( - arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.kBlockTileIterations); - const auto ws_desc_m_k = DeviceReduceMultiBlockPartialReduce::MakeWorkspace2dDescriptor( - arg.invariant_total_length, arg.blkGroupSize); - using InGridDesc_M_K = decltype(in_grid_desc_m_k); - using WorkspaceDesc_M_K = decltype(ws_desc_m_k); - - using GridwiseReduce = - GridwiseReduction_mk_to_mk_multiblock_partial_reduce; - - float avg_time = 0; - - const auto kernel = kernel_partial_reduce_multiblock; - - avg_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(arg.gridSize), - dim3(BlockSize), - 0, - in_grid_desc_m_k, - ws_desc_m_k, - arg.in_elementwise_op_, - arg.acc_elementwise_op_, - arg.blkGroupSize, - arg.kBlockTileIterations, - arg.in_dev_, - arg.workspace_dev_, - arg.workspace_indices_dev_); - - return (avg_time); - }; - - float Run(const BaseArgument* p_arg, int nrepeat = 1) override - { - return Run(*dynamic_cast(p_arg), nrepeat); - }; - }; - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - const Argument* pArg = dynamic_cast(p_arg); - - if constexpr(OutDstVectorSize != 1) - return (false); - - if constexpr(InSrcVectorDim == 0) - { - if constexpr(NumInvariantDim == 0) - { - return (false); - } - else - { - if(pArg->inStrides_[NumInvariantDim - 1] != 1) - return (false); - - if(pArg->invariant_lowest_length % InSrcVectorSize != 0) - return (false); - }; - } - else - { - if(pArg->inStrides_[Rank - 1] != 1) - return (false); - - if(pArg->reduce_lowest_length % InSrcVectorSize != 0) - return (false); - }; - - // cases with small reduce_total_length should be handled by the BlockWise method - if(pArg->reduce_total_length <= BlockSize * KThreadSliceSize) - return (false); - - return (true); - }; - - std::vector GetWorkspace2dLengths(const BaseArgument* p_arg) override - { - const Argument* pArg = dynamic_cast(p_arg); - - return ( - std::vector{static_cast(pArg->invariant_total_length), pArg->blkGroupSize}); - }; - - std::unique_ptr - MakeArgumentPointer(const std::vector inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, - const std::vector reduceDims, - float alpha, - float beta, - const void* in_dev, - void* out_dev, - void* out_indices_dev, - void* workspace_dev, - const InElementwiseOperation in_elementwise_op, - const AccElementwiseOperation acc_elementwise_op) override - { - return std::make_unique(inLengths, - inStrides, - outLengths, - outStrides, - reduceDims, - alpha, - beta, - static_cast(in_dev), - static_cast(out_dev), - static_cast(out_indices_dev), - static_cast(workspace_dev), - in_elementwise_op, - acc_elementwise_op); - }; - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(); - }; - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceReduceMultiBlockPartialReduce<" << BlockSize << ","; - str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; - str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; - str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp b/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp index bf4088a96b..9549bf65d2 100644 --- a/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/device/device_reduce_threadwise.hpp @@ -6,6 +6,7 @@ #include "device.hpp" #include "device_reduce.hpp" #include "device_reduce_common.hpp" +#include "gridwise_2d_reduction_multiblock.hpp" #include "gridwise_2d_reduction_threadwise.hpp" namespace ck { @@ -19,22 +20,19 @@ template -struct DeviceReduceThreadWise : public DeviceReduce +struct DeviceReduceThreadWise : public DeviceReduce { static_assert(Rank <= 6, "Bigger Rank size is not supported!"); - static_assert((BlockSize == MThreadClusterSize) && (KThreadClusterSize == 1), - "Threadwise can only be called with KThreadClusterSize be 1 !"); static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && @@ -43,7 +41,7 @@ struct DeviceReduceThreadWise : public DeviceReduce& inLengths, - const std::vector& inStrides) + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides) { const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); @@ -114,8 +112,8 @@ struct DeviceReduceThreadWise : public DeviceReduce& outLengths, - const std::vector& outStrides) + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) { const auto tupleDstLengths = make_tuple_from_array(outLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(outStrides, Number{}); @@ -143,30 +141,26 @@ struct DeviceReduceThreadWise : public DeviceReduce inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, const std::vector reduceDims, float alpha, float beta, const InDataType* in_dev, OutDataType* out_dev, - IndexDataType* out_indices_dev, - AccDataType* workspace_dev, + IndexDataType* out_index_dev, const InElementwiseOperation in_elementwise_op, - const OutElementwiseOperation acc_elementwise_op) + const AccElementwiseOperation acc_elementwise_op) : outLengths_{outLengths}, outStrides_{outStrides}, in_dev_{in_dev}, out_dev_{out_dev}, - out_indices_dev_{out_indices_dev}, + out_index_dev_{out_index_dev}, in_elementwise_op_{in_elementwise_op}, acc_elementwise_op_{acc_elementwise_op} - { - (void)workspace_dev; - inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); @@ -183,36 +177,39 @@ struct DeviceReduceThreadWise : public DeviceReduce inLengths_; - std::vector inStrides_; - std::vector outLengths_; - std::vector outStrides_; + std::vector inLengths_; + std::vector inStrides_; + std::vector outLengths_; + std::vector outStrides_; AccDataType alpha_; AccDataType beta_; const InDataType* in_dev_; OutDataType* out_dev_; - IndexDataType* out_indices_dev_; + IndexDataType* out_index_dev_; InElementwiseOperation in_elementwise_op_; - OutElementwiseOperation acc_elementwise_op_; + AccElementwiseOperation acc_elementwise_op_; - int invariant_lowest_length; - int reduce_lowest_length; - size_t invariant_total_length; - size_t reduce_total_length; + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + int numBlockTileIteration; size_t gridSize; }; struct Invoker : public BaseInvoker { - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { const auto in_grid_desc_m_k = DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); @@ -221,30 +218,30 @@ struct DeviceReduceThreadWise : public DeviceReduce; - float avg_time = 0; + using GridwiseReduce = + GridwiseReduction_mk_to_m_threadwise; + const auto kernel = kernel_reduce_threadwise; + AccElementwiseOperation>; - avg_time = launch_and_time_kernel(kernel, - nrepeat, + avg_time = launch_and_time_kernel(stream_config, + kernel, dim3(arg.gridSize), dim3(BlockSize), 0, @@ -265,16 +262,18 @@ struct DeviceReduceThreadWise : public DeviceReduce(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); }; }; @@ -310,9 +309,7 @@ struct DeviceReduceThreadWise : public DeviceReduceinvariant_lowest_length % OutDstVectorSize != 0) return (false); - // TODO: remove this. Should return true, as long as this DeviceOP instance support this - // case for bigger reduce_total_length size, we are supposed to use BlockWise method for - // better performance + // cases with big reduce_total_length should be handled by Blockwise kernel if(pArg->reduce_total_length / KThreadSliceSize >= 32) return (false); @@ -320,20 +317,22 @@ struct DeviceReduceThreadWise : public DeviceReduce - MakeArgumentPointer(const std::vector inLengths, - const std::vector inStrides, - const std::vector outLengths, - const std::vector outStrides, + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector outLengths, + const std::vector outStrides, const std::vector reduceDims, float alpha, float beta, const void* in_dev, + const void* in_index_dev, void* out_dev, - void* out_indices_dev, - void* workspace_dev, + void* out_index_dev, const InElementwiseOperation in_elementwise_op, - const OutElementwiseOperation acc_elementwise_op) override + const AccElementwiseOperation acc_elementwise_op) override { + (void)in_index_dev; + return std::make_unique(inLengths, inStrides, outLengths, @@ -343,8 +342,7 @@ struct DeviceReduceThreadWise : public DeviceReduce(in_dev), static_cast(out_dev), - static_cast(out_indices_dev), - static_cast(workspace_dev), + static_cast(out_index_dev), in_elementwise_op, acc_elementwise_op); }; @@ -359,9 +357,9 @@ struct DeviceReduceThreadWise : public DeviceReduce"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/device_softmax.hpp b/include/ck/tensor_operation/gpu/device/device_softmax.hpp new file mode 100644 index 0000000000..f4ade54204 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_softmax.hpp @@ -0,0 +1,203 @@ +#ifndef DEVICE_SOFTMAX_HPP +#define DEVICE_SOFTMAX_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_reduce.hpp" +#include "device_reduce_multiblock.hpp" +#include "device_reduce_common.hpp" +#include "gridwise_softmax.hpp" +#include "gridwise_set_buffer_value.hpp" +#include "reduction_operator.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSoftmax : public BaseOperator +{ + using PassThrough = tensor_operation::element_wise::PassThrough; + + // Used for freeloading of some handy functions from DeviceReduceMultiBlock + using Reduction = DeviceReduceMultiBlock; // OutDstVectorSize + + using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1)); + + using GridwiseReduce = GridwiseSoftmax_mk_to_mk; + + struct Argument : public Reduction::Argument + { + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + AccDataType alpha, + AccDataType beta, + const InDataType* in_dev, + OutDataType* out_dev) + : Reduction::Argument(inLengths, + inStrides, + {}, + {}, + reduceDims, + 0.0f, // alpha + 0.0f, // beta + in_dev, + nullptr, + out_dev, + nullptr, + PassThrough{}, + PassThrough{}), + // FIXME: The base class DeviceReduceMultiBlock::Argument only supports alpha/beta of + // float32 precision. Make it support any data type so the fields can be removed. + alpha_(alpha), + beta_(beta) + { + // std::cout << "blkGroupSize= " << this->blkGroupSize + // << ", numBlockTileIteration= " << this->numBlockTileIteration + // << ", gridSize=" << this->gridSize + // << ", invariant_total_length=" << this->invariant_total_length << + // std::endl; + } + + AccDataType alpha_; + AccDataType beta_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m_k = Reduction::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + + const auto kernel_main = + kernel_softmax; + + float avg_time = 0; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m_k, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.beta_, + arg.out_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + if(!Reduction::IsSupportedArgument(p_arg_)) + { + return false; + } + + if(p_arg_->inLengths_[Rank - 1] % OutDstVectorSize != 0) + { + return false; + } + + return true; + }; + + std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + AccDataType alpha, + AccDataType beta, + const void* in_dev, + void* out_dev) + { + return std::make_unique(inLengths, + inStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(out_dev)); + }; + + std::unique_ptr MakeInvokerPointer() { return std::make_unique(); }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceSoftmax<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif // DEVICE_SOFTMAX_HPP diff --git a/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp new file mode 100644 index 0000000000..4fcad7004f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_unary_elementwise.hpp @@ -0,0 +1,178 @@ +#pragma once +#include +#include + +#include "device.hpp" +#include "device_base.hpp" +#include "gridwise_unary_elementwise_1d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceUnaryElementwise : public BaseOperator +{ + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M0_1d(Desc_M0 desc_m0, index_t gridSize, index_t blockSize) + { + const auto m0 = desc_m0.GetLength(I0); + const index_t loop_step = gridSize * blockSize * ScalarPerVector; + const auto pad = math::integer_least_multiple(m0, loop_step) - m0; + const auto desc_m0_pad = + transform_tensor_descriptor(desc_m0, + make_tuple(make_right_pad_transform(m0, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m0_pad; + } + + static auto MakeDescriptor_M0(const std::vector& shape, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return shape[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] + if constexpr(Dim > 1) + { + const auto desc_m0 = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M0_1d(desc_m0, gridSize, blockSize); + } + else + return PadDescriptor_M0_1d(desc, gridSize, blockSize); + } + + using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); + using GridwiseUEltwise = GridwiseUnaryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + BDataType* p_b, + const std::vector& shape, + const std::vector& stride_a, + const std::vector& stride_b, + ElementwiseFunctor functor) + : p_a_(p_a), + p_b_(p_b), + shape_(shape), + functor_(functor), + blockSize_(256) // FIXME - Calculate the grid size by number of CU in the future + { + index_t tensor_size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}); + gridSize_ = GridwiseUEltwise::CalculateGridSize(tensor_size); + a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, blockSize_); + b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, blockSize_); + } + + const ADataType* p_a_; + BDataType* p_b_; + std::vector shape_; + GridDesc_M0 a_grid_desc_m0_; + GridDesc_M0 b_grid_desc_m0_; + ElementwiseFunctor functor_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_unary_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.p_a_, + arg.p_b_, + arg.a_grid_desc_m0_, + arg.b_grid_desc_m0_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + if(pArg->shape_.back() % ScalarPerVector != 0) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer(const void* p_a, + void* p_b, + std::vector shape, + std::vector stride_a, + std::vector stride_b, + ElementwiseFunctor functor) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + shape, + stride_a, + stride_b, + functor); + } + + std::unique_ptr MakeInvokerPointer() { return std::make_unique(); } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBinaryElementwise" + << "<" + << "ScalarPerVector = " << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp index 634e9212ea..4b3f52148d 100644 --- a/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp +++ b/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp @@ -29,6 +29,7 @@ #include "reduction_operator.hpp" #include "reduction_enums.hpp" #include "element_wise_operation.hpp" +#include namespace ck { @@ -37,77 +38,69 @@ namespace ck { // The boolean member "indexable" are also provided in reduce_binary_operactor for // easier checking by the upper-layer codes in the kernels. -template +template struct reduce_binary_operator; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Add; - using dataType = T; + using opType = reduce::Add; static constexpr bool indexable = false; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Mul; - using dataType = T; + using opType = reduce::Mul; static constexpr bool indexable = false; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Min; - using dataType = T; + using opType = reduce::Min; static constexpr bool indexable = true; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Max; - using dataType = T; + using opType = reduce::Max; static constexpr bool indexable = true; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::AMax; - using dataType = T; + using opType = reduce::AMax; static constexpr bool indexable = true; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Add; - using dataType = T; + using opType = reduce::Add; static constexpr bool indexable = false; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Add; - using dataType = T; + using opType = reduce::Add; static constexpr bool indexable = false; }; -template -struct reduce_binary_operator +template <> +struct reduce_binary_operator { - using opType = reduce::Add; - using dataType = T; + using opType = reduce::Add; static constexpr bool indexable = false; }; @@ -115,53 +108,101 @@ struct reduce_binary_operator // The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary // functor classes. // The two unary functors are called before and afer the Reduction is executed respectively -template +template struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct reduce_unary_operator +template +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::UnaryDivide; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{reduceLength}); + }; }; -template -struct reduce_unary_operator +template +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct reduce_unary_operator +template +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct reduce_unary_operator +template <> +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; - using AccElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct reduce_unary_operator +template <> +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; - using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; -template -struct reduce_unary_operator +template <> +struct reduce_unary_operator { - using InElementwiseOperation = tensor_operation::element_wise::UnaryIdentic; - using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; }; } // end of namespace ck diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp new file mode 100644 index 0000000000..300ce6fc0a --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -0,0 +1,215 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once + +#include "data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct Add +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + // Question: should half_t be supported ? + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 + x1; + }; + + // Question: should bhalf_t be supported ? + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = ck::type_convert(y_tmp); + } +}; + +struct Subtract +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 - x1; + }; + + // Question: should half_t be supported ? + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 - x1; + }; + + // Question: should bhalf_t be supported ? + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp - x2_tmp; + y = ck::type_convert(y_tmp); + } +}; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = alpha_ * x0 + beta_ * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = static_cast(alpha_) * x0 + static_cast(beta_) * x1; + }; + + // Question: should half_t be supported ? + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = static_cast(alpha_ * static_cast(x0) + beta_ * static_cast(x1)); + }; + + float alpha_; + float beta_; +}; + +struct AddRelu +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + const float a = x0 + x1; + y = a > 0.0f ? a : 0.0f; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + const double a = x0 + x1; + y = a > 0.0 ? a : 0.0; + }; + + // Question: should half_t be supported ? + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = a > static_cast(0.0f) ? a : static_cast(0.0f); + }; +}; + +struct AddHardswish +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f; + y = c; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + double a = x0 + x1; + double b = a + 3.0; + double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667; + y = c; + }; + + // Question: should half_t be supported ? + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + float a = x0 + x1; + float b = a + 3.0f; + float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f; + y = c; + }; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index ab1cbfed45..274d398e26 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,108 +1,64 @@ #pragma once + #include "data_type.hpp" +#include "math_v2.hpp" +#include "unary_element_wise_operation.hpp" +#include "binary_element_wise_operation.hpp" namespace ck { namespace tensor_operation { namespace element_wise { -struct PassThrough -{ - __host__ __device__ void operator()(float& y, const float& x) const { y = x; } - - __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; } - - __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { y = x; } - - __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; } - - __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; } - - __host__ __device__ void operator()(double& y, const double& x) const { y = x; } -}; - -struct Add -{ - __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const - { - y = x0 + x1; - } - - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1) const - { - // FIXME - Use float (acc type) bias in the future. - y = x0 + x1; - } -}; - -struct AlphaBetaAdd -{ - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta) {} - - __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const - { - y = alpha_ * x0 + beta_ * x1; - } - - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1) const - { - // FIXME - Let x0 be acc type - y = static_cast(alpha_ * static_cast(x0) + beta_ * static_cast(x1)); - } - - float alpha_; - float beta_; -}; - -struct AddRelu -{ - __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const - { - const float a = x0 + x1; - y = a > 0 ? a : 0; - } - - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1) const - { - const half_t a = x0 + x1; - y = a > 0 ? a : 0; - } -}; - -struct AddHardswish -{ - __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const - { - float a = x0 + x1; - float b = a + float{3}; - float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; - y = c; - } - - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1) const - { - float a = x0 + x1; - float b = a + float{3}; - float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; - y = c; - } -}; +// Need to ensure compiler will fail if there is no matching candidate, instead of compiler +// siliently do implicit type conversion +// +// Method 1: +// +// struct ExampleElementwiseOp +// { +// template +// __host__ __device__ constexpr void +// operator()(Y&, const X) const; +// +// template<> +// __host__ __device__ constexpr void +// operator()(half_t& y, const half_t& x) const +// { +// } +// }; +// +// Method 2: +// +// template +// struct ExampleElementwiseOp; +// +// template <> +// struct ExampleElementwiseOp +// { +// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const +// { +// } +// }; struct AddReluAdd { - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const { half_t a = x0 + x1; half_t b = a > 0 ? a : 0; y = b + x2; } - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1, const float& x2) const + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const { float a = x0 + x1; float b = a > 0 ? a : 0; @@ -110,8 +66,9 @@ struct AddReluAdd y = c; } - __host__ __device__ constexpr void - operator()(half_t& y, const float& x0, const half_t& x1, const half_t& x2) const + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const float& x0, const half_t& x1, const half_t& x2) const { float a = x0 + x1; float b = a > 0 ? a : 0; @@ -122,8 +79,14 @@ struct AddReluAdd struct AddHardswishAdd { - __host__ __device__ constexpr void - operator()(float& y, const float& x0, const float& x1, const float& x2) const + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const { float a = x0 + x1; float b = a + float{3}; @@ -132,8 +95,9 @@ struct AddHardswishAdd y = d; } - __host__ __device__ constexpr void - operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const { float a = x0 + x1; float b = a + float{3}; @@ -143,190 +107,95 @@ struct AddHardswishAdd } }; -// Unary operators are usually called element-wisely before/after the reduction is executed on the -// elements. They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 - -template -struct UnaryIdentic; - -template <> -struct UnaryIdentic +// C = A * B +// E = FastGelu(C + D0 + D1) +struct AddAddFastGelu { - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; + template + __host__ __device__ void operator()(E&, const C&, const D0&, const D1&) const; - __host__ __device__ void operator()(float& y, const float& x) const { y = x; }; + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + // Fast GeLU + // https://paperswithcode.com/method/gelu + // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) + const auto fast_gelu = [&](float x) { + const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); + const float emu = exp(-u); + const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1)); + return x * cdf; + }; + + const float y = fast_gelu(c + float(d0) + float(d1)); + + e = type_convert(y); + } }; -template <> -struct UnaryIdentic +struct Normalize { - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; }; + // FIXME: is double absolutely necessary? + Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {} - __host__ __device__ void operator()(float& y, const float& x) const + template + __host__ __device__ constexpr void operator()( + T& y, const T& x, const T& mean, const T& mean_square, const T& gamma, const T& beta) const; + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x, + const float& mean, + const float& mean_square, + const float& gamma, + const float& beta) const { - y = x / type_convert(divider_); + using ck::math::sqrt; + + float variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + static_cast(epsilon_))) * gamma + beta; }; - int32_t divider_ = 1; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; }; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const { y = x; }; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const + template <> + __host__ __device__ constexpr void operator()(double& y, + const double& x, + const double& mean, + const double& mean_square, + const double& gamma, + const double& beta) const { - y = x / type_convert(divider_); + using ck::math::sqrt; + + double variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta; }; - int32_t divider_ = 1; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x; }; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int32_t divider = 1) { divider_ = divider; }; - - __host__ __device__ void operator()(int32_t& y, const int32_t& x) const { y = x / divider_; }; - - int32_t divider_ = 1; -}; - -template <> -struct UnaryIdentic -{ - __host__ __device__ UnaryIdentic(const int8_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; }; -}; - -template -struct UnarySquare; - -template <> -struct UnarySquare -{ - __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(float& y, const float& x) const { y = x * x; }; -}; - -template <> -struct UnarySquare -{ - __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; }; - - __host__ __device__ void operator()(float& y, const float& x) const - { - y = x * x / type_convert(divider_); - }; - - int32_t divider_ = 1; -}; - -template <> -struct UnarySquare -{ - __host__ __device__ UnarySquare(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const { y = x * x; }; -}; - -template <> -struct UnarySquare -{ - __host__ __device__ UnarySquare(const int32_t divider = 1) { divider_ = divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const - { - y = x * x / type_convert(divider_); - }; - - int32_t divider_ = 1; + // FIXME: is double absolutely necessary? + double epsilon_; }; template -struct UnaryAbs; +struct UnaryTypeConvert; template <> -struct UnaryAbs +struct UnaryTypeConvert { - __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(float& y, const float& x) const { y = abs(x); }; -}; - -template <> -struct UnaryAbs -{ - __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = __habs(x); }; -}; - -template <> -struct UnaryAbs -{ - __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const { y = abs(x); }; -}; - -template <> -struct UnaryAbs -{ - __host__ __device__ UnaryAbs(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const { - int8_t sgn = x >> (8 - 1); - - y = (x ^ sgn) - sgn; - }; -}; - -template -struct UnarySqrt; - -template <> -struct UnarySqrt -{ - __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(float& y, const float& x) const { y = sqrtf(x); }; + y = ck::type_convert(x); + } }; template <> -struct UnarySqrt +struct UnaryTypeConvert { - __host__ __device__ UnarySqrt(const int32_t divider = 1) { (void)divider; }; - - __host__ __device__ void operator()(double& y, const double& x) const { y = sqrt(x); }; + __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const + { + y = ck::type_convert(x); + } }; } // namespace element_wise diff --git a/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp deleted file mode 100644 index 038e36f564..0000000000 --- a/include/ck/tensor_operation/gpu/element/element_wise_reduce_operation.hpp +++ /dev/null @@ -1,10 +0,0 @@ -#pragma once -#include "data_type.hpp" - -namespace ck { -namespace tensor_operation { -namespace element_wise { - -} // namespace element_wise -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp new file mode 100644 index 0000000000..c6142474cc --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -0,0 +1,120 @@ +#pragma once + +#include "data_type.hpp" +#include "math_v2.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct PassThrough +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = x; + }; +}; + +struct UnaryDivide +{ + __host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = x / type_convert(divider_); + }; + + int32_t divider_ = 1; +}; + +struct UnarySquare +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = x * x; + }; +}; + +struct UnaryAbs +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::abs(x); + }; +}; + +struct UnarySqrt +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sqrt(x); + }; +}; + +struct Relu +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + y = x > 0 ? x : 0; + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float x_f32 = ck::type_convert(x); + float y_f32 = x_f32 > 0 ? x_f32 : 0; + y = ck::type_convert(y_f32); + } +}; + +// https://paperswithcode.com/method/gelu +// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) +struct FastGelu +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + const float u = float(2) * x * (float(0.035677) * x * x + float(0.797885)); + const float emu = exp(-u); + const float cdf = float(0.5) + float(0.5) * (float(2) / (float(1) + emu) - float(1)); + + y = x * cdf; + } +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp new file mode 100644 index 0000000000..792060ca86 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -0,0 +1,489 @@ +#ifndef UTILITY_BLOCK_TO_CTILE_MAP +#define UTILITY_BLOCK_TO_CTILE_MAP + +#include "utility/math.hpp" +#include "utility/number.hpp" +#include "tensor_description/tensor_adaptor.hpp" +#include "tensor_description/multi_index_transform_helper.hpp" + +namespace ck { + +// Rows of column-vectors +template +struct BlockToCTileMap_M00_N0_M01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N0_M01() = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1) + : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + + const index_t grid_size = M00 * M01_ * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + if(M0 % M01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + + const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), + make_unmerge_transform(make_tuple(M00, M01)), + make_pass_through_transform(make_tuple(N0))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + + const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_n0_m01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1)); + UnderlyingMap underlying_map_; +}; + +// Rows of column-vectors +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) + : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// 2D slices of column-vectors in 3D space +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_KSplit_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8, + index_t KSplit = 1) + : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0 * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + + private: + index_t M01_; + index_t KSplit_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// Blocks of row-vectors +template +struct BlockToCTileMap_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1) + : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_, N01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1)); + UnderlyingMap underlying_map_; +}; + +// 2D slices of row-vectors in 3D space +template +struct BlockToCTileMap_KSplit_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1, + index_t KSplit = 1) + : M01_(M01), + N01_(N01), + KSplit_(KSplit), + underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01, + index_t KSplit) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(KSplit), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor; + } + + index_t M01_, N01_, KSplit_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1)); + UnderlyingMap underlying_map_; +}; + +template +__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) +{ + bool is_valid = false; + + const index_t m_block = c_tile_dim[Number<0>{}]; + const index_t n_block = c_tile_dim[Number<1>{}]; + + if constexpr(CTileIdx::Size() == 2) + { + const index_t m_block_idx = c_tile_idx[Number<0>{}]; + const index_t n_block_idx = c_tile_idx[Number<1>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + } + else if constexpr(CTileIdx::Size() == 3) + { + const index_t ksplit_idx = c_tile_idx[Number<0>{}]; + const index_t m_block_idx = c_tile_idx[Number<1>{}]; + const index_t n_block_idx = c_tile_idx[Number<2>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + ignore = ksplit_idx; + } + + return is_valid; +} + +} // namespace ck + +#endif // UTILITY_BLOCK_TO_CTILE_MAP diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_blockwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_blockwise.hpp deleted file mode 100644 index 6826d5211c..0000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_blockwise.hpp +++ /dev/null @@ -1,886 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2021 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP -#define CK_GRIDWISE_2D_REDUCTION_BLOCKWISE_HPP - -#include "data_type.hpp" -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_accumulate.hpp" -#include "reduction_functions_blockwise.hpp" -#include "reduction_functions_threadwise.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "cluster_descriptor.hpp" -#include "element_wise_operation.hpp" - -namespace ck { - -template -__global__ void kernel_reduce_blockwise(const InGridDesc_M_K in_grid_desc_m_k, - const OutGridDesc_M out_grid_desc_m, - const InElementwiseOperation in_elementwise_op, - const OutElementwiseOperation acc_elementwise_op, - AccDataType alpha, - const InDataType* const __restrict__ p_in_global, - AccDataType beta, - OutDataType* const __restrict__ p_out_global, - const IndexDataType* const __restrict__ p_ws_indices_global, - IndexDataType* const __restrict__ p_indices_global) -{ - if constexpr(!NeedIndices) - { - constexpr bool IsSecondCall = false; - - GridwiseReduction::template Run(in_grid_desc_m_k, - out_grid_desc_m, - in_elementwise_op, - acc_elementwise_op, - alpha, - p_in_global, - beta, - p_out_global, - p_ws_indices_global, - p_indices_global); - } - else - { - GridwiseReduction::RunWithIndex(in_grid_desc_m_k, - out_grid_desc_m, - in_elementwise_op, - acc_elementwise_op, - alpha, - p_in_global, - beta, - p_out_global, - p_ws_indices_global, - p_indices_global); - }; -}; - -template -__global__ void -kernel_reduce_blockwise_second_call(const InGridDesc_M_K in_grid_desc_m_k, - const OutGridDesc_M out_grid_desc_m, - const InElementwiseOperation in_elementwise_op, - const OutElementwiseOperation acc_elementwise_op, - AccDataType alpha, - const InDataType* const __restrict__ p_in_global, - AccDataType beta, - OutDataType* const __restrict__ p_out_global, - const IndexDataType* const __restrict__ p_ws_indices_global, - IndexDataType* const __restrict__ p_indices_global) -{ - if constexpr(!NeedIndices) - { - constexpr bool IsSecondCall = true; - - GridwiseReduction::template Run(in_grid_desc_m_k, - out_grid_desc_m, - in_elementwise_op, - acc_elementwise_op, - alpha, - p_in_global, - beta, - p_out_global, - p_ws_indices_global, - p_indices_global); - } - else - { - GridwiseReduction::RunSecondCallWithIndex(in_grid_desc_m_k, - out_grid_desc_m, - in_elementwise_op, - acc_elementwise_op, - alpha, - p_in_global, - beta, - p_out_global, - p_ws_indices_global, - p_indices_global); - }; -}; - -template -struct GridwiseReduction_mk_to_m_blockwise -{ - static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || - (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && - (MThreadSliceSize % OutDstVectorSize == 0), - "Invalid thread slice sizes and/or vector sizes configuration, please check!"); - - static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); - - using ThreadClusterLengths_M_K = Sequence; - - using ThreadBufferDimAccessOrder = - typename conditional, Sequence<0, 1>>::type; - - using ThreadClusterArrangeOrder = - typename conditional, Sequence<0, 1>>::type; - - static constexpr auto thread_cluster_desc = - make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - - using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}))); - using ThreadReduceDstDesc_M = - decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); - - using PassThroughOp = tensor_operation::element_wise::PassThrough; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - - template - __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, - const OutGridDesc_M& out_grid_desc_m, - const InElementwiseOperation& in_elementwise_op, - const OutElementwiseOperation& acc_elementwise_op, - AccDataType alpha, - const InDataType* const __restrict__ p_in_global, - AccDataType beta, - OutDataType* const __restrict__ p_out_global, - const IndexDataType* const __restrict__ p_ws_indices_global, - IndexDataType* const __restrict__ p_indices_global) - { - if constexpr(IsSecondCall) - { - static_assert(InSrcVectorDim == 1, - "InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!"); - }; - - using BlockwiseReduce = PartitionedBlockwiseReduction; - - using ThreadwiseReduce = ThreadwiseReduction; - - (void)p_ws_indices_global; - (void)p_indices_global; - - // LDS - __shared__ AccDataType p_reduce_work_buffer[BlockSize]; - - const auto zeroVal = ReduceOperation::GetReductionZeroVal(); - - const auto in_global_buf = make_dynamic_buffer( - p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert(zeroVal)); - auto out_global_buf = make_dynamic_buffer( - p_out_global, out_grid_desc_m.GetElementSpaceSize()); - - auto reduce_work_buf = - make_dynamic_buffer(p_reduce_work_buffer, BlockSize); - - StaticBuffer - in_thread_buf; - - StaticBuffer accu_value_buf; - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); - - const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_1d_id = get_block_1d_id(); - - const auto thread_cluster_idx = - thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); - - const auto thread_m_cluster_id = thread_cluster_idx[I0]; - const auto thread_k_cluster_id = thread_cluster_idx[I1]; - - using ThreadBufferLengths = Sequence; - constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( - in_grid_desc_m_k, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * KThreadSliceSize)); - - constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); - - const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize; - - index_t reducedTiles = 0; - do - { - threadwise_src_load.Run(in_grid_desc_m_k, - in_global_buf, - thread_buffer_desc, - make_tuple(I0, I0), - in_thread_buf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_buf(Number{}), - in_thread_buf(Number{})); - }); - }); - - ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); - - threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); - - reducedTiles++; - } while(reducedTiles < toReduceTiles); - - constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; - - static_for<0, MThreadSliceSize, 1>{}( - [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if(thread_k_cluster_id == 0) - { - acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); - - accu_value_buf(I) *= alpha; - } - }); - - if(thread_k_cluster_id == 0) - { - if constexpr(!BetaIsZero) - { - if(!float_equal_zero{}(beta)) - { - StaticBuffer - priorDstValueBuf; - - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - OutDstVectorSize, - 1, - false>( - out_grid_desc_m, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize)); - - threadwise_dst_load.Run(out_grid_desc_m, - out_global_buf, - reduced_data_desc, - make_tuple(I0), - priorDstValueBuf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; - }); - }; - }; - - auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - OutDstVectorSize, - InMemoryDataOperationEnum::Set, - 1, - true>( - out_grid_desc_m, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize), - PassThroughOp{}); - - threadwise_dst_store.Run( - reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf); - } - }; - - __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, - const OutGridDesc_M& out_grid_desc_m, - const InElementwiseOperation& in_elementwise_op, - const OutElementwiseOperation& acc_elementwise_op, - AccDataType alpha, - const InDataType* const __restrict__ p_in_global, - AccDataType beta, - OutDataType* const __restrict__ p_out_global, - const IndexDataType* const __restrict__ p_ws_indices_global, - IndexDataType* const __restrict__ p_indices_global) - { - using BlockwiseReduceWithIndex = - PartitionedBlockwiseReductionWithIndex; - - using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck; - - (void)p_ws_indices_global; - - // LDS - __shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; - __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; - - const auto zeroVal = ReduceOperation::GetReductionZeroVal(); - - const auto in_global_buf = make_dynamic_buffer( - p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert(zeroVal)); - auto out_global_val_buf = make_dynamic_buffer( - p_out_global, out_grid_desc_m.GetElementSpaceSize()); - auto out_global_idx_buf = make_dynamic_buffer( - p_indices_global, out_grid_desc_m.GetElementSpaceSize()); - - auto reduce_work_val_buf = - make_dynamic_buffer(p_reduce_work_val_buffer, BlockSize); - auto reduce_work_idx_buf = - make_dynamic_buffer(p_reduce_work_idx_buffer, BlockSize); - - StaticBuffer - in_thread_val_buf; - - StaticBuffer - in_thread_idx_buf; - - StaticBuffer accu_value_buf; - StaticBuffer accu_index_buf; - - const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_1d_id = get_block_1d_id(); - - const auto thread_cluster_idx = - thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); - - const auto thread_m_cluster_id = thread_cluster_idx[I0]; - const auto thread_k_cluster_id = thread_cluster_idx[I1]; - - using ThreadBufferLengths = Sequence; - constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( - in_grid_desc_m_k, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * KThreadSliceSize)); - - index_t indexOffset = 0; - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - accu_value_buf(I) = zeroVal; - accu_index_buf(I) = 0; - }); - - constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); - - const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize; - - index_t reducedTiles = 0; - do - { - // load the thread slice - threadwise_src_load.Run(in_grid_desc_m_k, - in_global_buf, - thread_buffer_desc, - make_tuple(I0, I0), - in_thread_val_buf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - - // initialize the indices for the per-thread to-reduce values - in_thread_idx_buf(Number{}) = - indexOffset + thread_k_cluster_id * KThreadSliceSize + iK(); - - // do element-wise pre-reduction operation - in_elementwise_op(in_thread_val_buf(Number{}), - in_thread_val_buf(Number{})); - }); - - AccDataType tmpValue = zeroVal; - IndexDataType tmpIndex = 0; - - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - - AccumulationWithIndex::Calculate(tmpValue, - in_thread_val_buf[Number{}], - tmpIndex, - in_thread_idx_buf[Number{}]); - }); - - BlockwiseReduceWithIndex::Reduce( - reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); - - AccumulationWithIndex::Calculate( - accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); - }); - - threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); - - indexOffset += K_BlockTileSize; - reducedTiles++; - } while(reducedTiles < toReduceTiles); - - constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if(thread_k_cluster_id == 0) - { - // for indiced operation, acc_elementwise_op shoud do nothing - acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); - - accu_value_buf(I) *= alpha; - } - }); - - if(thread_k_cluster_id == 0) - { - if constexpr(!BetaIsZero) - { - if(!float_equal_zero{}(beta)) - { - StaticBuffer - priorDstValueBuf; - - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - OutDstVectorSize, - 1, - false>( - out_grid_desc_m, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize)); - - threadwise_dst_load.Run(out_grid_desc_m, - out_global_val_buf, - reduced_data_desc, - make_tuple(I0), - priorDstValueBuf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; - }); - }; - }; - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - OutDstVectorSize, - InMemoryDataOperationEnum::Set, - 1, - false>( - out_grid_desc_m, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize), - PassThroughOp{}); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - OutDstVectorSize, - InMemoryDataOperationEnum::Set, - 1, - false>( - out_grid_desc_m, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize), - PassThroughOp{}); - - threadwise_dst_val_store.Run(reduced_data_desc, - make_tuple(I0), - accu_value_buf, - out_grid_desc_m, - out_global_val_buf); - threadwise_dst_idx_store.Run(reduced_data_desc, - make_tuple(I0), - accu_index_buf, - out_grid_desc_m, - out_global_idx_buf); - } - }; - - __device__ static void - RunSecondCallWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, - const OutGridDesc_M& out_grid_desc_m, - const InElementwiseOperation in_elementwise_op, - const OutElementwiseOperation acc_elementwise_op, - AccDataType alpha, - const InDataType* const __restrict__ p_ws_values_global, - AccDataType beta, - OutDataType* const __restrict__ p_out_global, - const IndexDataType* const __restrict__ p_ws_indices_global, - IndexDataType* const __restrict__ p_indices_global) - { - static_assert(InSrcVectorDim == 1, - "InSrcVectorDim must be 1 for BlockwiseSecondCall, please check!"); - - using BlockwiseReduceWithIndex = - PartitionedBlockwiseReductionWithIndex, - ThreadClusterArrangeOrder, - ReduceOperation, - PropagateNan>; - - using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck; - - (void)in_elementwise_op; - - // LDS - __shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; - __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; - - const auto zeroVal = ReduceOperation::GetReductionZeroVal(); - - const auto src_global_val_buf = - make_dynamic_buffer(p_ws_values_global, - in_grid_desc_m_k.GetElementSpaceSize(), - type_convert(zeroVal)); - const auto src_global_idx_buf = make_dynamic_buffer( - p_ws_indices_global, in_grid_desc_m_k.GetElementSpaceSize()); - auto out_global_val_buf = make_dynamic_buffer( - p_out_global, out_grid_desc_m.GetElementSpaceSize()); - auto out_global_idx_buf = make_dynamic_buffer( - p_indices_global, out_grid_desc_m.GetElementSpaceSize()); - - auto reduce_work_val_buf = - make_dynamic_buffer(p_reduce_work_val_buffer, BlockSize); - auto reduce_work_idx_buf = - make_dynamic_buffer(p_reduce_work_idx_buffer, BlockSize); - - StaticBuffer - in_thread_val_buf; - - StaticBuffer - in_thread_idx_buf; - - StaticBuffer accu_value_buf; - StaticBuffer accu_index_buf; - - const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_1d_id = get_block_1d_id(); - - const auto thread_cluster_idx = - thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); - - const auto thread_m_cluster_id = thread_cluster_idx[I0]; - const auto thread_k_cluster_id = thread_cluster_idx[I1]; - - using ThreadBufferLengths = Sequence; - constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - auto threadwise_src_val_load = - ThreadwiseTensorSliceTransfer_v2( - in_grid_desc_m_k, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * KThreadSliceSize)); - - auto threadwise_src_idx_load = - ThreadwiseTensorSliceTransfer_v2( - in_grid_desc_m_k, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - thread_k_cluster_id * KThreadSliceSize)); - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - accu_value_buf(I) = zeroVal; - accu_index_buf(I) = 0; - }); - - constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); - - const index_t toReduceTiles = (toReduceLength + K_BlockTileSize - 1) / K_BlockTileSize; - - index_t reducedTiles = 0; - do - { - // load the thread slice - threadwise_src_val_load.Run(in_grid_desc_m_k, - src_global_val_buf, - thread_buffer_desc, - make_tuple(I0, I0), - in_thread_val_buf); - threadwise_src_idx_load.Run(in_grid_desc_m_k, - src_global_idx_buf, - thread_buffer_desc, - make_tuple(I0, I0), - in_thread_idx_buf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - AccDataType tmpValue = zeroVal; - IndexDataType tmpIndex = 0; - - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - - AccumulationWithIndex::Calculate(tmpValue, - in_thread_val_buf[Number{}], - tmpIndex, - in_thread_idx_buf[Number{}]); - }); - - BlockwiseReduceWithIndex::Reduce( - reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); - - AccumulationWithIndex::Calculate( - accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); - }); - - threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); - threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); - - reducedTiles++; - } while(reducedTiles < toReduceTiles); - - constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if(thread_k_cluster_id == 0) - { - // for indiced operation, acc_elementwise_op shoud do nothing - acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); - - accu_value_buf(I) *= alpha; - } - }); - - if(thread_k_cluster_id == 0) - { - if constexpr(!BetaIsZero) - { - if(!float_equal_zero{}(beta)) - { - StaticBuffer - priorDstValueBuf; - - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - OutDstVectorSize, - 1, - true>( - out_grid_desc_m, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize)); - - threadwise_dst_load.Run(out_grid_desc_m, - out_global_val_buf, - reduced_data_desc, - make_tuple(I0), - priorDstValueBuf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; - }); - }; - }; - - auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - OutDstVectorSize, - InMemoryDataOperationEnum::Set, - 1, - true>( - out_grid_desc_m, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize), - PassThroughOp{}); - - auto threadwise_dst_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - OutDstVectorSize, - InMemoryDataOperationEnum::Set, - 1, - true>( - out_grid_desc_m, - make_multi_index(block_global_1d_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize), - PassThroughOp{}); - - threadwise_dst_val_store.Run(reduced_data_desc, - make_tuple(I0), - accu_value_buf, - out_grid_desc_m, - out_global_val_buf); - threadwise_dst_idx_store.Run(reduced_data_desc, - make_tuple(I0), - accu_index_buf, - out_grid_desc_m, - out_global_idx_buf); - } - }; -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp new file mode 100644 index 0000000000..4206a91406 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -0,0 +1,638 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP +#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_HPP + +#include "reduction_common.hpp" +#include "reduction_operator.hpp" +#include "reduction_functions_accumulate.hpp" +#include "reduction_functions_blockwise.hpp" +#include "reduction_functions_threadwise.hpp" + +#include "threadwise_tensor_slice_transfer.hpp" +#include "element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) +{ + if constexpr(!OutputIndex) + { + (void)p_in_index_global; + (void)p_out_index_global; + + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + block_group_size, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + beta, + p_out_value_global); + } + else + { + GridwiseReduction::template RunWithIndex(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); + }; +}; + +template +struct GridwiseReduction_mk_to_m_multiblock +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + using Accumulation = detail::AccumulateWithNanCheck; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}( + [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(block_group_size == 0 && !float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + false>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + true>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + } + }; + + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) + { + using BlockwiseReduceWithIndex = + PartitionedBlockwiseReductionWithIndex, + ThreadClusterArrangeOrder, + ReduceOperation, + PropagateNan>; + + using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck; + + (void)in_elementwise_op; + + // LDS + __shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; + __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + auto out_global_idx_buf = make_dynamic_buffer( + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); + + auto reduce_work_val_buf = + make_dynamic_buffer(p_reduce_work_val_buffer, BlockSize); + auto reduce_work_idx_buf = + make_dynamic_buffer(p_reduce_work_idx_buffer, BlockSize); + + StaticBuffer + in_thread_val_buf; + + StaticBuffer + in_thread_idx_buf; + + StaticBuffer accu_value_buf; + StaticBuffer accu_index_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_1d_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = identityVal; + accu_index_buf(I) = 0; + }); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + + if constexpr(HaveIndexInput) + { + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType tmpValue = identityVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + }); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + else + { + index_t indexOffset = 0; + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + // initialize the indices for the per-thread to-reduce values + in_thread_idx_buf(Number{}) = + indexOffset + thread_k_cluster_id * KThreadSliceSize + iK(); + + // do element-wise pre-reduction operation + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + + AccDataType tmpValue = identityVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + }); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexOffset += K_BlockTileSize; + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + }; + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + // for indiced operation, acc_elementwise_op shoud do nothing + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(!float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_val_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dst_idx_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_val_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + threadwise_dst_idx_store.Run(reduced_data_desc, + make_tuple(I0), + accu_index_buf, + out_grid_desc_m, + out_global_idx_buf); + } + }; +}; + +} // namespace ck +#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_atomic_add.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_atomic_add.hpp deleted file mode 100644 index 4e325f3573..0000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_atomic_add.hpp +++ /dev/null @@ -1,269 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP -#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_ATOMIC_ADD_HPP - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_accumulate.hpp" -#include "reduction_functions_blockwise.hpp" -#include "reduction_functions_threadwise.hpp" - -#include "threadwise_tensor_slice_transfer.hpp" -#include "element_wise_operation.hpp" - -namespace ck { - -template -__global__ void -kernel_reduce_multiblock_atocmi_add(const InGridDesc_M_K in_grid_desc_m_k, - const OutGridDesc_M out_grid_desc_m, - const InElementwiseOperation in_elementwise_op, - const AccElementwiseOperation acc_elementwise_op, - index_t block_group_size, - index_t num_k_block_tile_iteration, - AccDataType alpha, - const InDataType* const __restrict__ p_in_global, - OutDataType* const __restrict__ p_out_global) -{ - GridwiseReduction::Run(in_grid_desc_m_k, - out_grid_desc_m, - in_elementwise_op, - acc_elementwise_op, - block_group_size, - num_k_block_tile_iteration, - alpha, - p_in_global, - p_out_global); -}; - -template -struct GridwiseReduction_mk_to_m_multiblock_atomic_add -{ - static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || - (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && - (MThreadSliceSize % OutDstVectorSize == 0), - "Invalid thread slice sizes and/or vector sizes configuration, please check!"); - - static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); - - using ThreadClusterLengths_M_K = Sequence; - - using ThreadBufferDimAccessOrder = - typename conditional, Sequence<0, 1>>::type; - - using ThreadClusterArrangeOrder = - typename conditional, Sequence<0, 1>>::type; - - static constexpr auto thread_cluster_desc = - make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - - using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}))); - using ThreadReduceDstDesc_M = - decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); - - using BlockwiseReduce = PartitionedBlockwiseReduction; - - using ThreadwiseReduce = ThreadwiseReduction; - - using PassThroughOp = tensor_operation::element_wise::PassThrough; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - - using Accumulation = detail::AccumulateWithNanCheck; - - __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, - const OutGridDesc_M& out_grid_desc_m, - const InElementwiseOperation& in_elementwise_op, - const AccElementwiseOperation& acc_elementwise_op, - index_t block_group_size, - index_t num_k_block_tile_iteration, - AccDataType alpha, - const InDataType* const __restrict__ p_in_global, - OutDataType* const __restrict__ p_out_global) - { - const auto zeroVal = ReduceOperation::GetReductionZeroVal(); - - // LDS - __shared__ AccDataType p_reduce_work_buffer[BlockSize]; - - const auto in_global_buf = make_dynamic_buffer( - p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert(zeroVal)); - auto out_global_buf = make_dynamic_buffer( - p_out_global, out_grid_desc_m.GetElementSpaceSize()); - - auto reduce_work_buf = - make_dynamic_buffer(p_reduce_work_buffer, BlockSize); - - StaticBuffer - in_thread_buf; - - StaticBuffer accu_value_buf; - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_id = get_block_1d_id(); - const index_t blkgroup_id = block_global_id / block_group_size; - const index_t block_local_id = block_global_id % block_group_size; - - const auto thread_cluster_idx = - thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); - - const auto thread_m_cluster_id = thread_cluster_idx[I0]; - const auto thread_k_cluster_id = thread_cluster_idx[I1]; - - const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; - - using ThreadBufferLengths = Sequence; - constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( - in_grid_desc_m_k, - make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, - block_local_id * reduceSizePerBlock + - thread_k_cluster_id * KThreadSliceSize)); - - constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); - - index_t reducedTiles = 0; - do - { - threadwise_src_load.Run(in_grid_desc_m_k, - in_global_buf, - thread_buffer_desc, - make_tuple(I0, I0), - in_thread_buf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_buf(Number{}), - in_thread_buf(Number{})); - }); - }); - - ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); - - threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); - - reducedTiles++; - } while(reducedTiles < num_k_block_tile_iteration); - - constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; - - // Each block executes multiple parallel reductions on the LDS, and by atomic-adding its - // reduced output to the global location corresponding to each invariant dimension to get a - // consistent reduced result for that invariant dimension. due to the using of vector_load, - // each block/thread is involved into multiple invarirant dimensions. - static_for<0, MThreadSliceSize, 1>{}( - [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if(thread_k_cluster_id == 0) - { - acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); - - accu_value_buf(I) *= alpha; - } - }); - - if(thread_k_cluster_id == 0) - { - auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - OutDstVectorSize, - InMemoryDataOperationEnum::AtomicAdd, - 1, - true>( - out_grid_desc_m, - make_multi_index(blkgroup_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize), - PassThroughOp{}); - - threadwise_dst_store.Run( - reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_buf); - } - }; -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_partial_reduce.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_partial_reduce.hpp deleted file mode 100644 index d1be1f5275..0000000000 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock_partial_reduce.hpp +++ /dev/null @@ -1,487 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP -#define CK_GRIDWISE_2D_REDUCTION_MULTIBLOCK_PARTIAL_REDUCE_HPP - -#include "reduction_common.hpp" -#include "reduction_operator.hpp" -#include "reduction_functions_accumulate.hpp" -#include "reduction_functions_blockwise.hpp" -#include "reduction_functions_threadwise.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "cluster_descriptor.hpp" -#include "element_wise_operation.hpp" - -namespace ck { - -template -__global__ void -kernel_partial_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, - const WorkspaceDesc_M_K workspace_desc_m_k, - const InElementwiseOperation in_elementwise_op, - const AccElementwiseOperation acc_elementwise_op, - index_t block_group_size, - index_t num_k_block_tile_iteration, - const InDataType* const __restrict__ p_src_global, - AccDataType* const __restrict__ p_ws_values_global, - IndexDataType* const __restrict__ p_ws_indices_global) - -{ - if constexpr(!NeedIndices) - { - GridwiseReduction::Run(in_grid_desc_m_k, - workspace_desc_m_k, - in_elementwise_op, - acc_elementwise_op, - block_group_size, - num_k_block_tile_iteration, - p_src_global, - p_ws_values_global, - p_ws_indices_global); - } - else - { - GridwiseReduction::RunWithIndex(in_grid_desc_m_k, - workspace_desc_m_k, - in_elementwise_op, - acc_elementwise_op, - block_group_size, - num_k_block_tile_iteration, - p_src_global, - p_ws_values_global, - p_ws_indices_global); - }; -}; - -template -struct GridwiseReduction_mk_to_mk_multiblock_partial_reduce -{ - static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || - (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0), - "Invalid thread slice sizes and/or vector sizes configuration, please check!"); - - static_assert(OutDstVectorSize == 1, "OutDstVectorSize must be 1 for MultiBlockPartialReduce!"); - - static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); - - using ThreadClusterLengths_M_K = Sequence; - - using ThreadBufferDimAccessOrder = - typename conditional, Sequence<0, 1>>::type; - - using ThreadClusterArrangeOrder = - typename conditional, Sequence<0, 1>>::type; - - static constexpr auto thread_cluster_desc = - make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); - - using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}))); - using ThreadReduceDstDesc_M = - decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); - - using PassThroughOp = tensor_operation::element_wise::PassThrough; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; - - __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, - const WorkspaceDesc_M_K& workspace_desc_m_k, - const InElementwiseOperation& in_elementwise_op, - const AccElementwiseOperation& acc_elementwise_op, - index_t block_group_size, - index_t num_k_block_tile_iteration, - const InDataType* const __restrict__ p_src_global, - AccDataType* const __restrict__ p_ws_values_global, - IndexDataType* const __restrict__ p_ws_indices_global) - { - using BlockwiseReduce = PartitionedBlockwiseReduction; - - using ThreadwiseReduce = ThreadwiseReduction; - - (void)p_ws_indices_global; - (void)acc_elementwise_op; - - const auto zeroVal = ReduceOperation::GetReductionZeroVal(); - - // LDS - __shared__ AccDataType p_reduce_work_buffer[BlockSize]; - - const auto in_global_buf = - make_dynamic_buffer(p_src_global, - in_grid_desc_m_k.GetElementSpaceSize(), - type_convert(zeroVal)); - auto workspace_global_buf = make_dynamic_buffer( - p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); - - auto reduce_work_buf = - make_dynamic_buffer(p_reduce_work_buffer, BlockSize); - - StaticBuffer - in_thread_buf; - - StaticBuffer accu_value_buf; - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_id = get_block_1d_id(); - const index_t blkgroup_id = block_global_id / block_group_size; - const index_t block_local_id = block_global_id % block_group_size; - - const auto thread_cluster_idx = - thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); - - const auto thread_m_cluster_id = thread_cluster_idx[I0]; - const auto thread_k_cluster_id = thread_cluster_idx[I1]; - - const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; - - using ThreadBufferLengths = Sequence; - constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( - in_grid_desc_m_k, - make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, - block_local_id * reduceSizePerBlock + - thread_k_cluster_id * KThreadSliceSize)); - - constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); - - index_t reducedTiles = 0; - do - { - threadwise_src_load.Run(in_grid_desc_m_k, - in_global_buf, - thread_buffer_desc, - make_tuple(I0, I0), - in_thread_buf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - in_elementwise_op(in_thread_buf(Number{}), - in_thread_buf(Number{})); - }); - }); - - ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); - - threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); - - reducedTiles++; - } while(reducedTiles < num_k_block_tile_iteration); - - // Each block executes multiple parallel reductions on the LDS, and due to the using of - // vector_load, each block/thread is involved into multiple invarirant dimensions. - static_for<0, MThreadSliceSize, 1>{}( - [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); - - constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number<1>{})); - - if(thread_k_cluster_id == 0) - { - auto threadwise_workspace_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1>, - 1, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>( - workspace_desc_m_k, - make_multi_index(blkgroup_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - block_local_id), - PassThroughOp{}); - - threadwise_workspace_store.Run(reduced_data_desc, - make_tuple(I0, I0), - accu_value_buf, - workspace_desc_m_k, - workspace_global_buf); - } - }; - - __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, - const WorkspaceDesc_M_K& workspace_desc_m_k, - const InElementwiseOperation& in_elementwise_op, - const AccElementwiseOperation& acc_elementwise_op, - index_t block_group_size, - index_t num_k_block_tile_iteration, - const InDataType* const __restrict__ p_src_global, - AccDataType* const __restrict__ p_ws_values_global, - IndexDataType* const __restrict__ p_ws_indices_global) - { - using BlockwiseReduceWithIndex = - PartitionedBlockwiseReductionWithIndex; - - using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck; - - (void)acc_elementwise_op; - - const auto zeroVal = ReduceOperation::GetReductionZeroVal(); - - // LDS - __shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; - __shared__ index_t p_reduce_work_idx_buffer[BlockSize]; - - const auto in_global_buf = - make_dynamic_buffer(p_src_global, - in_grid_desc_m_k.GetElementSpaceSize(), - type_convert(zeroVal)); - auto workspace_global_val_buf = make_dynamic_buffer( - p_ws_values_global, workspace_desc_m_k.GetElementSpaceSize()); - auto workspace_global_idx_buf = make_dynamic_buffer( - p_ws_indices_global, workspace_desc_m_k.GetElementSpaceSize()); - - auto reduce_work_val_buf = - make_dynamic_buffer(p_reduce_work_val_buffer, BlockSize); - auto reduce_work_idx_buf = - make_dynamic_buffer(p_reduce_work_idx_buffer, BlockSize); - - StaticBuffer - in_thread_val_buf; - StaticBuffer - in_thread_idx_buf; - - StaticBuffer accu_value_buf; - StaticBuffer accu_index_buf; - - const index_t thread_local_id = get_thread_local_1d_id(); - const index_t block_global_id = get_block_1d_id(); - const index_t blkgroup_id = block_global_id / block_group_size; - const index_t block_local_id = block_global_id % block_group_size; - - const auto thread_cluster_idx = - thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); - - const auto thread_m_cluster_id = thread_cluster_idx[I0]; - const auto thread_k_cluster_id = thread_cluster_idx[I1]; - - const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; - - using ThreadBufferLengths = Sequence; - constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( - in_grid_desc_m_k, - make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, - block_local_id * reduceSizePerBlock + - thread_k_cluster_id * KThreadSliceSize)); - - constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); - - index_t indexOffset = block_local_id * reduceSizePerBlock; - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - accu_value_buf(I) = zeroVal; - accu_index_buf(I) = 0; - }); - - index_t reducedTiles = 0; - do - { - // load the thread slice - threadwise_src_load.Run(in_grid_desc_m_k, - in_global_buf, - thread_buffer_desc, - make_tuple(I0, I0), - in_thread_val_buf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - - // initialize the indices for the per-thread to-reduce values - in_thread_idx_buf(Number{}) = - indexOffset + thread_k_cluster_id * KThreadSliceSize + iK(); - - // do element-wise pre-reduction operation - in_elementwise_op(in_thread_val_buf(Number{}), - in_thread_val_buf(Number{})); - }); - - AccDataType tmpValue = zeroVal; - IndexDataType tmpIndex = 0; - - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); - - AccumulationWithIndex::Calculate(tmpValue, - in_thread_val_buf[Number{}], - tmpIndex, - in_thread_idx_buf[Number{}]); - }); - - BlockwiseReduceWithIndex::Reduce( - reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); - - AccumulationWithIndex::Calculate( - accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); - }); - - threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); - - indexOffset += K_BlockTileSize; - - reducedTiles++; - } while(reducedTiles < num_k_block_tile_iteration); - - constexpr auto reduced_data_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number<1>{})); - - if(thread_k_cluster_id == 0) - { - auto threadwise_workspace_val_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1>, - 1, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>( - workspace_desc_m_k, - make_multi_index(blkgroup_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - block_local_id), - PassThroughOp{}); - - auto threadwise_workspace_idx_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1>, - 1, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>( - workspace_desc_m_k, - make_multi_index(blkgroup_id * M_BlockTileSize + - thread_m_cluster_id * MThreadSliceSize, - block_local_id), - PassThroughOp{}); - - threadwise_workspace_val_store.Run(reduced_data_desc, - make_tuple(I0, I0), - accu_value_buf, - workspace_desc_m_k, - workspace_global_val_buf); - threadwise_workspace_idx_store.Run(reduced_data_desc, - make_tuple(I0, I0), - accu_index_buf, - workspace_desc_m_k, - workspace_global_idx_buf); - } - }; -}; - -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp index c047f7e375..d6e4bbd4cb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -37,7 +37,8 @@ namespace ck { template (in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); }; }; @@ -91,11 +93,9 @@ template ; - (void)p_indices_global; + const auto identityVal = ReduceOperation::template GetIdentityValue(); - const auto zeroVal = ReduceOperation::GetReductionZeroVal(); - - const auto in_global_buf = make_dynamic_buffer( - p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert(zeroVal)); + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); auto dst_global_buf = make_dynamic_buffer( - p_out_global, out_grid_desc_m.GetElementSpaceSize()); + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); StaticBuffer in_thread_buf; StaticBuffer accu_value_buf; - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = zeroVal; }); + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); @@ -160,28 +159,29 @@ struct GridwiseReduction_mk_to_m_threadwise index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( - in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); index_t reducedLength = 0; do { - threadwise_src_load.Run(in_grid_desc_m_k, - in_global_buf, - thread_buffer_desc, - make_tuple(I0, I0), - in_thread_buf); + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // do element-wise pre-reduction operation @@ -194,7 +194,7 @@ struct GridwiseReduction_mk_to_m_threadwise ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); - threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); reducedLength += KThreadSliceSize; } while(reducedLength < toReduceLength); @@ -207,68 +207,65 @@ struct GridwiseReduction_mk_to_m_threadwise constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; - if constexpr(!BetaIsZero) + if(!float_equal_zero{}(beta)) { - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - true>( - out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + true>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); - StaticBuffer - priorDstValue_buf; + StaticBuffer + priorDstValue_buf; - threadwise_dst_load.Run(out_grid_desc_m, - dst_global_buf, - reduced_data_desc, - make_tuple(I0), - priorDstValue_buf); + threadwise_dst_load.Run(out_grid_desc_m, + dst_global_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; - }); - }; + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); }; - auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0>, - 0, - OutDstVectorSize, - InMemoryDataOperationEnum::Set, - 1, - false>( - out_grid_desc_m, - make_multi_index(thread_global_1d_id * MThreadSliceSize), - PassThroughOp{}); + auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); threadwise_dst_store.Run( reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); }; - __device__ static void RunWithIndices(const InGridDesc_M_K& in_grid_desc_m_k, - const OutGridDesc_M& out_grid_desc_m, - const InElementwiseOperation& in_elementwise_op, - const AccElementwiseOperation& acc_elementwise_op, - AccDataType alpha, - const InDataType* const __restrict__ p_in_global, - AccDataType beta, - OutDataType* const __restrict__ p_out_global, - IndexDataType* const __restrict__ p_indices_global) + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) { using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); - const auto in_global_buf = make_dynamic_buffer( - p_in_global, in_grid_desc_m_k.GetElementSpaceSize(), type_convert(zeroVal)); auto out_global_val_buf = make_dynamic_buffer( - p_out_global, out_grid_desc_m.GetElementSpaceSize()); + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); auto out_global_idx_buf = make_dynamic_buffer( - p_indices_global, out_grid_desc_m.GetElementSpaceSize()); + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); StaticBuffer in_thread_val_buf; @@ -301,7 +303,7 @@ struct GridwiseReduction_mk_to_m_threadwise StaticBuffer accu_index_buf; static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - accu_value_buf(I) = zeroVal; + accu_value_buf(I) = identityVal; accu_index_buf(I) = 0; }); @@ -313,50 +315,105 @@ struct GridwiseReduction_mk_to_m_threadwise index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); - auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( - in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); index_t indexStart = 0; index_t reducedLength = 0; - do + if constexpr(HaveIndexInput) { - threadwise_src_load.Run(in_grid_desc_m_k, - in_global_buf, - thread_buffer_desc, - make_tuple(I0, I0), - in_thread_val_buf); + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - // do element-wise pre-reduction operation - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); - in_thread_idx_buf(Number{}) = indexStart + iK(); + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); - in_elementwise_op(in_thread_val_buf(Number{}), - in_thread_val_buf(Number{})); + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); }); - }); - ThreadwiseReduceWithIndex::Reduce( - in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); - threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); - indexStart += KThreadSliceSize; - reducedLength += KThreadSliceSize; - } while(reducedLength < toReduceLength); + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + } + else + { + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_thread_idx_buf(Number{}) = indexStart + iK(); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + }); + + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + }; // for indiced operation, acc_elementwise_op shoud do nothing static_for<0, MThreadSliceSize, 1>{}([&](auto I) { @@ -367,36 +424,32 @@ struct GridwiseReduction_mk_to_m_threadwise constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; - if constexpr(!BetaIsZero) + if(!float_equal_zero{}(beta)) { - if(!float_equal_zero{}(beta)) - { - auto threadwise_dst_load = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0>, - 0, - 1, - 1, - false>( - out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + false>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); - StaticBuffer - priorDstValue_buf; + StaticBuffer + priorDstValue_buf; - threadwise_dst_load.Run(out_grid_desc_m, - out_global_val_buf, - reduced_data_desc, - make_tuple(I0), - priorDstValue_buf); + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; - }); - }; + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); }; auto threadwise_dst_val_store = @@ -409,7 +462,7 @@ struct GridwiseReduction_mk_to_m_threadwise Sequence<0>, 0, OutDstVectorSize, - InMemoryDataOperationEnum::Set, + OutMemoryDataOperation, 1, false>( out_grid_desc_m, @@ -426,7 +479,7 @@ struct GridwiseReduction_mk_to_m_threadwise Sequence<0>, 0, OutDstVectorSize, - InMemoryDataOperationEnum::Set, + OutMemoryDataOperation, 1, false>( out_grid_desc_m, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp new file mode 100644 index 0000000000..d3342b072e --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp @@ -0,0 +1,251 @@ +#pragma once + +#include "cluster_descriptor.hpp" +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + const CDataType* __restrict__ p_c_global, + const DDataType* __restrict__ p_d_global, + const EDataType* __restrict__ p_e_global, + FDataType* __restrict__ p_f_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const DGridDesc_M d_grid_desc_m, + const EGridDesc_M e_grid_desc_m, + const FGridDesc_M f_grid_desc_m, + const ElementwiseFunctor functor) +{ + Gridwise5AryEltwise::Run(p_a_global, + p_b_global, + p_c_global, + p_d_global, + p_e_global, + p_f_global, + a_grid_desc_m, + b_grid_desc_m, + c_grid_desc_m, + d_grid_desc_m, + e_grid_desc_m, + f_grid_desc_m, + functor); +} + +// TODO - implement n-ary Elemenetwise_1D, tuple of inputs and tuple of outputs +template +struct Gridwise5AryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * MPerThread); + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + const CDataType* __restrict__ p_c_global, + const DDataType* __restrict__ p_d_global, + const EDataType* __restrict__ p_e_global, + FDataType* __restrict__ p_f_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const DGridDesc_M d_grid_desc_m, + const EGridDesc_M e_grid_desc_m, + const FGridDesc_M f_grid_desc_m, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m.GetElementSpaceSize()); + const auto c_global_buf = make_dynamic_buffer( + p_c_global, c_grid_desc_m.GetElementSpaceSize()); + const auto d_global_buf = make_dynamic_buffer( + p_d_global, d_grid_desc_m.GetElementSpaceSize()); + const auto e_global_buf = make_dynamic_buffer( + p_e_global, e_grid_desc_m.GetElementSpaceSize()); + auto f_global_buf = make_dynamic_buffer( + p_f_global, f_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + StaticBuffer c_thread_buf; + StaticBuffer d_thread_buf; + StaticBuffer e_thread_buf; + StaticBuffer f_thread_buf; + + const auto thread_store_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + AScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m, thread_store_global_offset}; + + auto b_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + BScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{b_grid_desc_m, thread_store_global_offset}; + + auto c_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + CScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{c_grid_desc_m, thread_store_global_offset}; + + auto d_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + DScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{d_grid_desc_m, thread_store_global_offset}; + + auto e_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + EScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{e_grid_desc_m, thread_store_global_offset}; + + auto f_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + FScalarPerVector, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + f_grid_desc_m, thread_store_global_offset, PassThrough{}}; + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = c_grid_desc_m.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * MPerThread; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = M / (loop_step); + do + { + // read and process MPerThread elements + a_global_load.Run( + a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf); + + b_global_load.Run( + b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf); + + c_global_load.Run( + c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf); + + d_global_load.Run( + d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf); + + e_global_load.Run( + e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf); + + static_for<0, MPerThread, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m)); + functor(f_thread_buf(Number{}), + a_thread_buf(Number{}), + b_thread_buf(Number{}), + c_thread_buf(Number{}), + d_thread_buf(Number{}), + e_thread_buf(Number{})); + }); + + f_global_write.Run(thread_desc_m, + make_tuple(I0), // SrcSliceOriginIdx + f_thread_buf, + f_grid_desc_m, + f_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index); + b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index); + c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index); + d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index); + e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index); + f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp new file mode 100644 index 0000000000..374c4fe59a --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include "cluster_descriptor.hpp" +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_binary_elementwise_1d(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + CDataType* __restrict__ p_c_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const ElementwiseFunctor functor) +{ + GridwiseBinEltwise::Run( + p_a_global, p_b_global, p_c_global, a_grid_desc_m, b_grid_desc_m, c_grid_desc_m, functor); +} + +template +struct GridwiseBinaryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * MPerThread); + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + CDataType* __restrict__ p_c_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m.GetElementSpaceSize()); + auto c_global_buf = make_dynamic_buffer( + p_c_global, c_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + StaticBuffer c_thread_buf; + + const auto thread_store_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + AScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m, thread_store_global_offset}; + + auto b_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + BScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{b_grid_desc_m, thread_store_global_offset}; + + auto c_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + CScalarPerVector, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + c_grid_desc_m, thread_store_global_offset, PassThrough{}}; + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = c_grid_desc_m.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * MPerThread; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = M / (loop_step); + do + { + // read and process MPerThread elements + a_global_load.Run( + a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf); + + b_global_load.Run( + b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf); + + static_for<0, MPerThread, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m)); + functor(c_thread_buf(Number{}), + a_thread_buf(Number{}), + b_thread_buf(Number{})); + }); + + c_global_write.Run(thread_desc_m, + make_tuple(I0), // SrcSliceOriginIdx + c_thread_buf, + c_grid_desc_m, + c_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index); + b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index); + c_global_write.MoveDstSliceWindow(c_grid_desc_m, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000..0b790d4e38 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,989 @@ +#pragma once +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" +#include "reduction_functions_threadwise.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_c0_grid, + const FloatC1* __restrict__ p_c1_grid, + DPtrsGlobal p_ds_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const C1ElementwiseOperation c1_element_op, + const DxsInElementwiseOperation dxs_in_element_op, + const DxsReduceAccElementwiseOperation dxs_out_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_c1_grid, + p_ds_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + dxs_in_element_op, + dxs_out_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + d_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = p_c1_grid; + ignore = p_ds_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c1_element_op; + ignore = dxs_in_element_op; + ignore = dxs_out_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = d_grid_desc_mblock_mperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeDGridDescriptor_MBlock_MPerBlock(const DGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto d_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return d_grid_desc_mblock_mperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_c0_grid, + const FloatC1* __restrict__ p_c1_grid, + DPtrsGlobal p_ds_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const C1ElementwiseOperation& c1_element_op, + const DxsInElementwiseOperation& dxs_in_element_op, + const DxsReduceAccElementwiseOperation& dxs_out_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const DGridDescriptor_MBlock_MPerBlock& d_grid_desc_mblock_mperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_c1_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR d_reduce_thread_desc_mperblock + constexpr auto d_reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR d_reduce_thread_desc_mblock_mperblock + constexpr auto d_reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_d_grid = p_ds_grid[I]; + auto d_out_element_op = dxs_out_element_op[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(d_reduce_thread_desc_mblock_mperblock), + decltype(d_grid_desc_mblock_mperblock), + decltype(d_out_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + DGlobalMemoryDataOperation::At(I), + 1, + false>{d_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + d_out_element_op}; + }, + Number{}); + + // c0 and c1 + constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock; + + auto c01_thread_buf = make_static_buffer( + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatReduceAcc, + decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC1, + FloatReduceAcc, + decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatC, + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to write to LDS + block_sync_lds(); + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_buf, + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = activation(c + bias) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + FloatReduceAcc out; + c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i)); + c_reduce_thread_buf(i) = out; + }); + + c1_thread_copy_global_to_vgpr.Run( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_buf, + c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = c + c1_functior(c1) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c1_element_op(c01_thread_buf(i), c01_thread_buf(i)); + c_reduce_thread_buf(i) += c01_thread_buf(i); + }); + + c_reduce_thread_copy_vgpr_to_global.Run( + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c_reduce_thread_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { + auto& p_d_grid = p_ds_grid[In]; + + auto d_grid_buf = make_dynamic_buffer( + p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto d_thread_buf = + make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& d_in_element_op = dxs_in_element_op[In]; + + auto& d_reduce_thread_copy_vgpr_to_global = + dxs_reduce_thread_copy_vgpr_to_global(In); + + using DReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto d_zeroVal = + DReduceOperation::template GetIdentityValue(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { d_thread_buf(I) = d_zeroVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + d_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf); + + // copy from VGPR to Global + d_reduce_thread_copy_vgpr_to_global.Run( + d_reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + d_thread_buf, + d_grid_desc_mblock_mperblock, + d_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + d_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0 + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C1 + c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } // Reduction + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp similarity index 57% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 1a66c8ff3f..3b5daf6ead 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dlops_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -1,38 +1,38 @@ -#ifndef CK_GRIDWISE_GEMM_V1R3_HPP -#define CK_GRIDWISE_GEMM_V1R3_HPP +#pragma once #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_dlops_v2r3.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_dl_v2r3.hpp" #include "blockwise_tensor_slice_transfer_v5r1.hpp" -#include "threadwise_tensor_slice_transfer_v2.hpp" +#include "threadwise_tensor_slice_transfer.hpp" #include "threadwise_tensor_slice_set.hpp" +#include "element_wise_operation.hpp" namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_dlops_v1r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor cblockid_to_m0_n0_block_cluster_adaptor) + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -43,10 +43,10 @@ __global__ void p_b_grid, p_c_grid, p_shared_block, - a_k0_m0_m1_k1_grid_desc, - b_k0_n0_n1_k1_grid_desc, - c_m0_m10_m11_n0_n10_n11_grid_desc, - cblockid_to_m0_n0_block_cluster_adaptor, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, integral_constant{}, integral_constant{}); } @@ -56,12 +56,12 @@ template -struct GridwiseGemmDlops_km_kn_mn_v1r3 + index_t CThreadTransferDstScalarPerVector> +struct GridwiseGemmDl_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -97,7 +92,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 static constexpr auto I3 = Number<3>{}; // K1 should be Number<...> - static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); + static constexpr auto K1 = AGridDesc_K0_M_K1{}.GetLength(I2); __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -106,112 +101,112 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // TODO: check alignment // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_k_m_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_k_n_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_aligned_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); } __host__ __device__ static constexpr bool - CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, - const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) { - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && - K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); + return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && + K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); } __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) { - const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); return grid_size; } __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) { - const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; + const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; return has_main_k_block_loop; } __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) { - const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; + const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; return has_double_tail_k_block_loop; } __host__ __device__ static constexpr auto - MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) + MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1) { - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto M1 = Number{}; + const auto M1 = Number{}; const auto M0 = M / M1; - const auto a_k0_m0_m1_k1_grid_desc = - transform_tensor_descriptor(a_k0_m_k1_grid_desc, + const auto a_grid_desc_k0_m0_m1_k1 = + transform_tensor_descriptor(a_grid_desc_k0_m_k1, make_tuple(make_pass_through_transform(K0), make_unmerge_transform(make_tuple(M0, M1)), make_pass_through_transform(K1)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - return a_k0_m0_m1_k1_grid_desc; + return a_grid_desc_k0_m0_m1_k1; } __host__ __device__ static constexpr auto - MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) + MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) { - const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto N1 = Number{}; + const auto N1 = Number{}; const auto N0 = N / N1; - const auto b_k0_n0_n1_k1_grid_desc = - transform_tensor_descriptor(b_k0_n_k1_grid_desc, + const auto b_grid_desc_k0_n0_n1_k1 = + transform_tensor_descriptor(b_grid_desc_k0_n_k1, make_tuple(make_pass_through_transform(K0), make_unmerge_transform(make_tuple(N0, N1)), make_pass_through_transform(K1)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - return b_k0_n0_n1_k1_grid_desc; + return b_grid_desc_k0_n0_n1_k1; } __host__ __device__ static constexpr auto - MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; const auto M0 = M / M1; const auto N0 = N / N1; @@ -226,41 +221,29 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 constexpr auto M10 = M1 / M11; constexpr auto N10 = N1 / N11; - const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, + const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( + c_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), make_unmerge_transform(make_tuple(N0, N10, N11))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); - return c_m0_m10_m11_n0_n10_n11_grid_desc; + return c_grid_desc_m0_m10_m11_n0_n10_n11; } + // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto - MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return cblockid_to_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_M00_N00_M01_N01( + c_grid_desc_m_n); } - using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); - using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); - using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); - using CBlockIdToM0N0BlockClusterAdaptor = - decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); template __device__ static void @@ -268,57 +251,64 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, FloatAB* __restrict__ p_shared_block, - const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, - const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, - const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, - const CBlockIdToM0N0BlockClusterAdaptor& cblockid_to_m0_n0_block_cluster_adaptor, + const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap& block_2_ctile_map, integral_constant, integral_constant) { const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); + p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); + p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); // divide block work by [M, N] const auto c_m0_n0_block_cluster_idx = - cblockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); // HACK: this force index data into SGPR const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + if(!block_2_ctile_map.ValidCTileIndex( + make_tuple(im0, in0), + make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0), + c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3)))) + { + return; + } + // TODO: change this. I think it needs multi-dimensional alignment constexpr auto max_lds_align = K1; // TODO: check alignment // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_k0_m0_m1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_k0_n0_n1_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); // TODO: check alignment // A matrix in LDS memory, for blockwise GEMM constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); // TODO: check alignment // B matrix in LDS memory, for blockwise GEMM constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); - static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == + static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == a_k0_m_k1_block_desc.GetElementSpaceSize() && - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() == b_k0_n_k1_block_desc.GetElementSpaceSize() && "wrong!"); @@ -326,14 +316,14 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(a_k0_m0_m1_k1_grid_desc), - decltype(a_k0_m0_m1_k1_block_desc), + remove_reference_t, + decltype(a_block_desc_k0_m0_m1_k1), ABlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3>, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths @@ -341,23 +331,23 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder false, - true>(a_k0_m0_m1_k1_grid_desc, + true>(a_grid_desc_k0_m0_m1_k1, make_multi_index(0, im0, 0, 0), - a_k0_m0_m1_k1_block_desc, + a_block_desc_k0_m0_m1_k1, make_multi_index(0, 0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< BlockSize, InMemoryDataOperationEnum::Set, - Sequence, + Sequence, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(b_k0_n0_n1_k1_grid_desc), - decltype(b_k0_n0_n1_k1_block_desc), + remove_reference_t, + decltype(b_block_desc_k0_n0_n1_k1), BBlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3>, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths @@ -365,19 +355,19 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder false, - true>(b_k0_n0_n1_k1_grid_desc, + true>(b_grid_desc_k0_n0_n1_k1, make_multi_index(0, in0, 0, 0), - b_k0_n0_n1_k1_block_desc, + b_block_desc_k0_n0_n1_k1, make_multi_index(0, 0, 0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlockM1] is in LDS - // b_mtx[KPerBlocl, NPerBlockN1] is in LDS - // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlocl, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register const auto blockwise_gemm = - BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< BlockSize, FloatAB, FloatAB, @@ -395,58 +385,53 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - constexpr auto c_m10_m11_n10_n11_thread_desc = make_naive_tensor_descriptor_packed( + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); FloatAB* p_a_block_double = p_shared_block; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output auto c_thread_buf = make_static_buffer( - c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); - ThreadwiseTensorSliceSet_v1{} - .Run(c_m10_m11_n10_n11_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); + // Initialize C + c_thread_buf.Clear(); - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); auto a_block_odd_buf = make_dynamic_buffer( p_a_block_double + a_block_aligned_space_size, - a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); auto b_block_odd_buf = make_dynamic_buffer( p_b_block_double + b_block_aligned_space_size, - b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); // LDS double buffer: preload data into LDS { - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); } if constexpr(HasMainKBlockLoop) { - const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); + const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); index_t k_block_data_begin = 0; @@ -455,82 +440,76 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 do { // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, - a_block_slice_copy_step, - AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, - b_block_slice_copy_step, - BGridMoveSliceWindowStepHacks{}); - - __syncthreads(); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); // LDS double buffer: GEMM on current data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < K0 - 2 * KPerBlock); + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); } // LDS double buffer: tail if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left { - a_blockwise_copy.MoveSrcSliceWindow( - a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step, AGridMoveSliceWindowStepHacks{}); - b_blockwise_copy.MoveSrcSliceWindow( - b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step, BGridMoveSliceWindowStepHacks{}); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); - __syncthreads(); + block_sync_lds(); // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridStepHacks{}); - b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridStepHacks{}); + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); - __syncthreads(); + block_sync_lds(); // LDS double buffer: GEMM on last data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); } else // if has 1 iteration left { @@ -538,12 +517,12 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 // LDS double buffer: GEMM on last data blockwise_gemm.Run( - c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); } // output: register to global memory { - constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = make_naive_tensor_descriptor_packed( make_tuple(I1, Number{}, @@ -559,8 +538,9 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ThreadwiseTensorSliceTransfer_v1r3< FloatAcc, FloatC, - decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), - decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), + ck::tensor_operation::element_wise::PassThrough, Sequence<1, c_m10_m11_n10_n11_thread_tensor_lengths[I0], c_m10_m11_n10_n11_thread_tensor_lengths[I1], @@ -572,22 +552,21 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, - true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, make_multi_index(im0, c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], in0, c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], - c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} - .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, make_tuple(I0, I0, I0, I0, I0, I0), c_thread_buf, - c_m0_m10_m11_n0_n10_n11_grid_desc, - c_grid_buf, - CGridStepHacks{}); + c_grid_desc_m0_m10_m11_n0_n10_n11, + c_grid_buf); } } }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100644 index 0000000000..3ec098486b --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,668 @@ +#pragma once + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v7.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" + +namespace ck { + +// input : A[AK0, M, AK1] +// input : B[AK0, N, AK1] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const StaticallyIndexedArray& + ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different + // type from E + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(FloatCShuffle{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 6a1b6eef31..20c3a0b618 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -1,5 +1,6 @@ #pragma once #include "common_header.hpp" +#include "tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" namespace ck { @@ -248,4 +249,116 @@ struct GridwiseGemmPipeline_v1<2> } }; +template +struct GridwiseGemmPipelineInterwave_v1; + +template <> +struct GridwiseGemmPipelineInterwave_v1<1> +{ + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // block_sync_lds(); // moved into blockwise_gemm + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// Note: 2 stage prefetch not optimized for inter-wave loop scheduler +template <> +struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> +{ +}; + +template +constexpr auto GridwiseGemmPipeline_v1_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return GridwiseGemmPipeline_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return GridwiseGemmPipelineInterwave_v1{}; + } +} + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index 25aade2f3a..80a6eeace6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -3,9 +3,10 @@ #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" -#include "blockwise_tensor_slice_transfer_v6r1.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" #include "reduction_functions_threadwise.hpp" @@ -15,11 +16,12 @@ namespace ck { template (p_a_grid, - p_b_grid, - p_c_grid, - p_d0_grid, - p_d1_grid, - p_shared, - a_element_op, - b_element_op, - c_element_op, - d1_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - d_grid_desc_mblock_mperblock, - block_2_ctile_map); + p_b_grid, + p_c_grid, + p_ds_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + dxs_in_element_op, + dxs_out_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + d_grid_desc_mblock_mperblock, + block_2_ctile_map); #else ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; - ignore = p_d0_grid; - ignore = p_d1_grid; + ignore = p_ds_grid; ignore = a_element_op; ignore = b_element_op; ignore = c_element_op; - ignore = d1_element_op; + ignore = dxs_in_element_op; + ignore = dxs_out_element_op; ignore = a_grid_desc_ak0_m_ak1; ignore = b_grid_desc_bk0_n_bk1; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; @@ -88,15 +90,15 @@ template + index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopScheduler LoopSched> struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -216,10 +219,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, - const CGridDesc_M_N& c_grid_desc_m_n) + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { // static_assert(is_known_at_compile_time>::value && // is_known_at_compile_time>::value, @@ -247,21 +252,15 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return false; } + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; @@ -307,40 +306,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - // FIXME: remove - constexpr auto M01 = I1; - constexpr auto N01 = I1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - auto d0_grid_buf = make_dynamic_buffer( - p_d0_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); - auto d1_grid_buf = make_dynamic_buffer( - p_d1_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); // divide block work by [M, N] const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -403,28 +374,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -434,28 +405,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -473,17 +444,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr index_t KPack = math::max( math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -502,27 +474,30 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - num_k_block_main_loop); + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); - // shuffle C and write out + // shuffle C + reduction + write out { static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, @@ -636,8 +611,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ck::tensor_operation::element_wise::PassThrough{}}; // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = BlockwiseTensorSliceTransfer_v6r1< - BlockSize, // index_t BlockSize, + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup CElementwiseOperation, // ElementwiseOperation, CGlobalMemoryDataOperation, // DstInMemOp, Sequence<1, @@ -661,6 +636,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), c_element_op}; + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction // LDS c_reduce_block_desc_mperblock_nperblock constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, @@ -711,16 +709,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto d_reduce_thread_desc_mblock_mperblock = make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); - // TODO: this should be implemented as a blockwise reduction auto c_reduce_thread_buf = make_static_buffer( c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); - auto d0_thread_buf = make_static_buffer( - d_reduce_thread_desc_mperblock.GetElementSpaceSize()); - - auto d1_thread_buf = make_static_buffer( - d_reduce_thread_desc_mperblock.GetElementSpaceSize()); - // reduce: threadwise copy from LDS to VGPR constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); @@ -744,47 +735,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 1, true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; - // reduce: copy from VGPR to global - auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< - FloatReduceAcc, - FloatD, - decltype(d_reduce_thread_desc_mblock_mperblock), - decltype(d_grid_desc_mblock_mperblock), - ck::tensor_operation::element_wise::PassThrough, - Sequence<1, mreduce_per_thread>, - Sequence<0, 1>, - 1, - CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, - DGlobalMemoryDataOperation, - 1, - false>{d_grid_desc_mblock_mperblock, - make_multi_index(block_work_idx[I0], // mblock - c_reduce_thread_data_idx_begin[I0]), // mperblock - ck::tensor_operation::element_wise::PassThrough{}}; + auto dxs_reduce_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_d_grid = p_ds_grid[I]; + auto d_out_element_op = dxs_out_element_op[I]; - auto d1_reduce_thread_copy_vgpr_to_global = d0_reduce_thread_copy_vgpr_to_global; - - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(d_reduce_thread_desc_mblock_mperblock), + decltype(d_grid_desc_mblock_mperblock), + decltype(d_out_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + DGlobalMemoryDataOperation::At(I), + 1, + false>{d_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + d_out_element_op}; + }, + Number{}); constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); @@ -811,64 +784,74 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 c_grid_desc_mblock_mperblock_nblock_nperblock, c_grid_buf); - using ThreadwiseReduce_D0 = - ThreadwiseReduction; - - using ThreadwiseReduce_D1 = - ThreadwiseReduction; - - const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal(); - const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal(); - - static_for<0, mreduce_per_thread, 1>{}( - [&](auto I) { d0_thread_buf(I) = d0_zeroVal; }); - static_for<0, mreduce_per_thread, 1>{}( - [&](auto I) { d1_thread_buf(I) = d1_zeroVal; }); - - // reduce + // TODO - extract following into reduction_blockwise { - // copy from LDS to VGPR c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, c_shuffle_block_buf, c_reduce_thread_desc_mperblock_nperblock, make_tuple(I0, I0), c_reduce_thread_buf); - // reduce in VGPR - ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf); + static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) { + auto& p_d_grid = p_ds_grid[In]; - static_for<0, mreduce_per_thread, 1>{}([&](auto im) { - static_for<0, nreduce_per_thread, 1>{}([&](auto in) { - constexpr auto offset = - Number{}; + auto d_grid_buf = make_dynamic_buffer( + p_d_grid, d_grid_desc_mblock_mperblock.GetElementSpaceSize()); - d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset)); + auto d_thread_buf = + make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& d_in_element_op = dxs_in_element_op[In]; + + auto& d_reduce_thread_copy_vgpr_to_global = + dxs_reduce_thread_copy_vgpr_to_global(In); + + using DReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto d_identityVal = + DReduceOperation::template GetIdentityValue(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { d_thread_buf(I) = d_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + d_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, d_thread_buf); + + // copy from VGPR to Global + d_reduce_thread_copy_vgpr_to_global.Run( + d_reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + d_thread_buf, + d_grid_desc_mblock_mperblock, + d_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + d_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + d_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } }); - - ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf); - - // copy from VGPR to Global - d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock, - make_tuple(I0, I0), - d0_thread_buf, - d_grid_desc_mblock_mperblock, - d0_grid_buf); - - d1_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock, - make_tuple(I0, I0), - d1_thread_buf, - d_grid_desc_mblock_mperblock, - d1_grid_buf); } if constexpr(access_id < num_access - 1) @@ -878,18 +861,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 // move on C c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - - // move on D0 - d0_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( - d_grid_desc_mblock_mperblock, - make_tuple(c_global_step[I0], c_global_step[I1])); - - // move on D1 - d1_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( - d_grid_desc_mblock_mperblock, - make_tuple(c_global_step[I0], c_global_step[I1])); } }); + + // Reduction } } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index a9aa53e071..a16db7b78b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -3,6 +3,7 @@ #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "blockwise_gemm_xdlops.hpp" #include "thread_group_tensor_slice_transfer_v4r1.hpp" #include "thread_group_tensor_slice_transfer_v6r1.hpp" @@ -108,7 +109,8 @@ template + index_t CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopScheduler LoopSched> struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 { static constexpr auto I0 = Number<0>{}; @@ -194,10 +196,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, - const CGridDesc_M_N& c_grid_desc_m_n) + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, @@ -221,21 +225,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 return false; } + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; @@ -266,40 +264,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - // FIXME: remove - constexpr auto M01 = I1; - constexpr auto N01 = I1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t::selected_mfma.k_per_blk); - auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -450,25 +425,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - blockwise_gemm, - c_thread_buf, - num_k_block_main_loop); + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); // shuffle C and write out { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp new file mode 100644 index 0000000000..b1f3779802 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -0,0 +1,976 @@ +#pragma once + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v6r1.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "gridwise_gemm_pipeline_v1.hpp" + +namespace ck { + +// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct Merge_v4_no_carry +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v4_no_carry() = default; + + __host__ __device__ constexpr Merge_v4_no_carry(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_up_diff, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto INm1 = Number{}; + + index_t tmp = idx_up_new[I0]; + + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_up_diff[I0]; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v3_direct_division_mod_wrw, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +__host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLengths& low_lengths) +{ + return Merge_v4_no_carry{low_lengths}; +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_bwd_weight(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + // M0/M1/M1Padding + static constexpr auto M1PerBlock = Number{}; + static constexpr auto M0PerBlock = Number{}; + static constexpr auto M1Padding = Number{}; + + // N0/N1/N1Padding + static constexpr auto N1PerBlock = Number{}; + static constexpr auto N0PerBlock = Number{}; + static constexpr auto N1Padding = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number{}), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_block_desc_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_b_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + M1Padding), + Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_b_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return a_block_desc_b_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return a_block_desc_b_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number{}), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_block_desc_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_b_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + N1Padding), + Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_b_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return b_block_desc_b_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return b_block_desc_b_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = math::integer_least_multiple( + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = math::integer_least_multiple( + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + // const bool has_main_k0_block_loop = K0 > K0PerBlock; + const index_t num_loop = K0 / K0PerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + + // return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N00_M01_N01( + c_m_n_grid_desc, M01, N01, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + constexpr index_t KPack = + math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + void* p_shared = static_cast(p_shared_block); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXDL, ""); + static_assert(N2 == NPerXDL, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; // namespace ck + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 7a5af95fd9..fb2af8967f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,10 +1,12 @@ #pragma once + #include "common_header.hpp" #include "multi_index_transform_helper.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" +#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "blockwise_gemm_xdlops.hpp" -#include "blockwise_tensor_slice_transfer_v4r1.hpp" +#include "thread_group_tensor_slice_transfer_v4r1.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "gridwise_gemm_pipeline_v1.hpp" #include "gridwise_gemm_pipeline_v2.hpp" @@ -190,12 +192,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -224,31 +226,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return false; } - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { return false; + } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / (K0PerBlock * K1); @@ -292,7 +278,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 }(); using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = @@ -373,6 +331,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1)))) + { + return; + } + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -391,28 +357,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -422,28 +388,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -460,7 +426,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // sanity check auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __host__ __device__ static constexpr bool CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, - index_t M01, - index_t N01) + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -196,31 +197,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { return false; + } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch; - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { const bool has_main_k0_block_loop = K0 > K0PerBlock; @@ -264,7 +249,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 }(); using BlockwiseGemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(KBatch), - make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); - - const auto cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_kbatch_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_kbatch_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); } using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); @@ -344,6 +300,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 const auto block_work_idx = c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0), + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1)))) + { + return; + } + const index_t k_batch_id = block_work_idx[I0]; // HACK: this force m/n_block_data_idx_on_grid into SGPR @@ -422,27 +386,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 }(); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_b_k0_m_k1_grid_desc), - decltype(a_b_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - ABlockTransferSrcVectorDim, - 3, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( a_b_k0_m_k1_grid_desc, make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), a_element_op, @@ -452,27 +416,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_b_k0_n_k1_grid_desc), - decltype(b_b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - BBlockTransferSrcVectorDim, - 3, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( b_b_k0_n_k1_grid_desc, make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), b_element_op, @@ -489,7 +453,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 // sanity check auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, - index_t M01, - index_t N01) + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -203,31 +204,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { return false; + } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch) - { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch; - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { const bool has_main_k0_block_loop = K0 > K0PerBlock; @@ -254,39 +239,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // return block_id to C matrix tile idx (m0, n0) mapping __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( - const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) { - const auto M = c_m_n_grid_desc.GetLength(I0); - const auto N = c_m_n_grid_desc.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(KBatch), - make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); - - const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor); - - return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); } __host__ __device__ static constexpr auto @@ -333,6 +289,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 const auto block_work_idx = c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + const index_t k_batch_id = block_work_idx[I0]; // HACK: this force m/n_block_data_idx_on_grid into SGPR @@ -411,27 +375,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 }(); // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_b_k0_m_k1_grid_desc), - decltype(a_b_k0_m_k1_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - ABlockTransferSrcVectorDim, - 3, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( a_b_k0_m_k1_grid_desc, make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), a_element_op, @@ -441,27 +405,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_b_k0_n_k1_grid_desc), - decltype(b_b_k0_n_k1_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 2, 1, 3>, - BBlockTransferSrcVectorDim, - 3, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( b_b_k0_n_k1_grid_desc, make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), b_element_op, @@ -478,7 +442,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // sanity check auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) + const Block2CTileMap& block_2_ctile_map) { // static_assert(is_known_at_compile_time>::value && // is_known_at_compile_time>::value, @@ -258,31 +257,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 return false; } - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { return false; + } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / KPerBlock; @@ -317,39 +300,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 } // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -387,6 +342,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -405,28 +371,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -436,28 +402,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -476,7 +442,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -269,31 +270,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 return false; } - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { return false; + } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / (K0PerBlock * K1); @@ -329,40 +314,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 } // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = remove_cvref_t, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -462,28 +431,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - NumGemmKPrefetchStage>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -500,7 +469,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 // sanity check auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01, - index_t N01) + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -273,31 +272,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 return false; } - // check M01, N01 - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - if(!(M0 % M01 == 0 && N0 % N01 == 0)) + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { return false; + } // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); - - return grid_size; - } - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { const index_t num_loop = K / (K0PerBlock * K1); @@ -333,39 +316,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 } // return block_id to C matrix tile idx (m0, n0) mapping - __host__ __device__ static constexpr auto - MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - const auto M0 = M / M1; - const auto N0 = N / N1; - - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; - - const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); - - const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto cblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return cblockid_to_m0_n0_block_cluster_adaptor; + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); } using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = remove_cvref_t; - template + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -429,6 +384,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -447,27 +413,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 // A matrix blockwise copy auto a_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_grid_desc_k0_m_k1), - decltype(a_block_desc_k0_m_k1), - ABlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( a_grid_desc_k0_m_k1, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, @@ -477,27 +443,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 // B matrix blockwise copy auto b_blockwise_copy = - BlockwiseTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0_n_k1), - BBlockTransferSrcAccessOrder, - Sequence<1, 0, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( b_grid_desc_k0_n_k1, make_multi_index(0, n_block_data_idx_on_grid, 0), b_element_op, @@ -514,7 +480,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 // sanity check auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + using PassThroughOp = tensor_operation::element_wise::PassThrough; constexpr auto I0 = Number<0>{}; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp new file mode 100644 index 0000000000..de293eed35 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -0,0 +1,407 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GRIDWISE_SOFTMAX_HPP +#define GRIDWISE_SOFTMAX_HPP + +#include "reduction_common.hpp" +#include "reduction_operator.hpp" +#include "reduction_functions_accumulate.hpp" +#include "reduction_functions_blockwise.hpp" +#include "reduction_functions_threadwise.hpp" + +#include "threadwise_tensor_slice_transfer.hpp" +#include "element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, + const GridDesc_M_K out_grid_desc_m_k, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) +{ + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m_k, + block_group_size, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + beta, + p_out_value_global); +}; + +template +struct GridwiseSoftmax_mk_to_mk +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (KThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseMaxReduce = PartitionedBlockwiseReduction; // PropagateNan + + using ThreadwiseMaxReduce = ThreadwiseReduction; // PropagateNan + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k, + const GridDesc_M_K& out_grid_desc_m_k, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + StaticBuffer + out_thread_buf; + + StaticBuffer max_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + max_value_buf(I) = reduce::Max::template GetIdentityValue(); + }); + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = reduce::Add::template GetIdentityValue(); + }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2( + out_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3( + out_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + constexpr auto in_thread_copy_fwd_step = make_multi_index(0, K_BlockTileSize); + constexpr auto in_thread_copy_bwd_step = make_multi_index(0, -K_BlockTileSize); + + /// + /// max(x) + /// + const auto in_global_val_buf_oob_non_zero = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + reduce::Max::template GetIdentityValue()); + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf_oob_non_zero, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + static_for<0, MThreadSliceSize, 1>{}( + [&](auto I) { BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step); + + /// + /// sum(exp(x - max(x))) + /// + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = reduce::Add::template GetIdentityValue(); + }); + + // Normally, 0 as invalid element value is adequate since 0 makes no contribution to + // accumulated result. However, in stable softmax, all values 0s or not are subtracted by + // another value_max. As numbers become non-zero, effectively it allows invalid values to + // slip through and contribute to the accumulated result. + // + // The trick here is leveraging the fact that many math functions (add, sub, exp, ...) + // propagate NaNs when operands have NaNs involved. By initialiing invalid element value + // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still + // be identified as an invalid value. We can then discard the invalid values which + // originally failed the bound check during accumulation. This allows to ignore values that + // failed bound check even after multiple math manipulations. + const auto in_global_val_buf_oob_nan = + make_dynamic_buffer(p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + NumericLimits::QuietNaN()); + + using BlockwiseSumReduce = PartitionedBlockwiseReduction< + AccDataType, + BlockSize, + ThreadClusterLengths_M_K, + ThreadClusterArrangeOrder, + reduce::Add, + false, // ignored + detail::AccumulateWithNanIgnore>; + + using ThreadwiseSumReduce = + ThreadwiseReduction>; + + reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf_oob_nan, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + // do element-wise pre-reduction operation + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_thread_buf(Number{}) = + math::exp(in_thread_buf(Number{}) - max_value_buf(iM)); + }); + }); + + ThreadwiseSumReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I)); + // block_sync_lds(); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + + /// + /// softmax + /// + reducedTiles = 0; + if(float_equal_zero{}(beta)) + { + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf_oob_nan, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM); + }); + }); + + threadwise_dst_store.Run(thread_buffer_desc, + make_tuple(I0, I0), + out_thread_buf, + out_grid_desc_m_k, + out_global_val_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + else + { + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf_oob_nan, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + threadwise_dst_load.Run(out_grid_desc_m_k, + out_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + out_thread_buf); + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM) + + beta * out_thread_buf(Number{}); + }); + }); + + threadwise_dst_store.Run(thread_buffer_desc, + make_tuple(I0, I0), + out_thread_buf, + out_grid_desc_m_k, + out_global_val_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + } +}; + +} // namespace ck +#endif // GRIDWISE_SOFTMAX_HPP diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp new file mode 100644 index 0000000000..5773068756 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp @@ -0,0 +1,129 @@ +#pragma once + +#include "cluster_descriptor.hpp" +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_unary_elementwise_1d(const ADataType* __restrict__ p_a_global, + BDataType* __restrict__ p_b_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const ElementwiseFunctor functor) +{ + GridwiseUEltwise::Run(p_a_global, p_b_global, a_grid_desc_m0, b_grid_desc_m0, functor); +} + +template +struct GridwiseUnaryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m0 = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * ScalarPerVector); + } + + __host__ __device__ static constexpr bool CheckValidity(const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0) + { + return a_grid_desc_m0.GetLength(I0) == b_grid_desc_m0.GetLength(I0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(const index_t tensor_size) + { + const index_t grid_size = math::integer_divide_ceil(tensor_size, 256 * ScalarPerVector); + + return grid_size; + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + BDataType* __restrict__ p_b_global, + const GridDesc_M0 a_grid_desc_m0, + const GridDesc_M0 b_grid_desc_m0, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m0.GetElementSpaceSize()); + auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m0.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + + const auto thread_store_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + ScalarPerVector, + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m0, thread_store_global_offset}; + + auto b_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + ScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + b_grid_desc_m0, thread_store_global_offset, PassThrough{}}; + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto m0 = b_grid_desc_m0.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * ScalarPerVector; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = m0 / (loop_step); + do + { + // read and process ScalarPerVector elements + a_global_load.Run( + a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf); + + static_for<0, ScalarPerVector, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m)); + functor(b_thread_buf(Number{}), a_thread_buf(Number{})); + }); + + b_global_write.Run(thread_desc_m0, + make_tuple(I0), // SrcSliceOriginIdx + b_thread_buf, + b_grid_desc_m0, + b_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index); + b_global_write.MoveDstSliceWindow(b_grid_desc_m0, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp index 3dcfe3a030..35fc1b929d 100644 --- a/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp +++ b/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -39,7 +39,9 @@ template + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithNanCheck> struct ThreadwiseReduction { static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; @@ -51,8 +53,6 @@ struct ThreadwiseReduction static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); - using Accumulation = detail::AccumulateWithNanCheck; - template __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf) { @@ -73,12 +73,15 @@ struct ThreadwiseReduction // 2) DstDesc is known at compile-time // 3) SrcBuffer is static buffer // 4) DstBuffer is static buffer -template +template < + typename AccDataType, + typename IndexDataType, + typename SrcThreadDesc_M_K, + typename DstThreadDesc_M, + typename OpReduce, + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithIndexAndNanCheck> struct ThreadwiseReductionWithIndex { static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; @@ -90,9 +93,6 @@ struct ThreadwiseReductionWithIndex static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); - using Accumulation = - detail::AccumulateWithIndexAndNanCheck; - template ::type = false> -struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 +struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1 { - __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() + __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -124,9 +122,9 @@ template ::type = false> -struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 { - __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() { static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && @@ -220,4 +218,3 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index 48338ddfa6..f0e9c7e761 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -1,5 +1,4 @@ -#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP -#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP +#pragma once #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1 }; } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp new file mode 100644 index 0000000000..782e456f3d --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp @@ -0,0 +1,295 @@ +#pragma once + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "tensor_space_filling_curve.hpp" + +namespace ck { + +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename DimAccessOrder, + index_t VectorDim, + index_t ScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags> // Sequence +struct ThreadwiseTensorSliceTransfer_v7 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = + SpaceFillingCurve>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + auto generate_vectors = [&](auto data_types) { + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + }; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(SrcDatas{}); + auto dst_vectors = generate_vectors(DstDatas{}); + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), + is_src_valid); + }); + + // apply pointwise function + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + return dst_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 9d72abb72e..a39b795818 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -25,6 +25,7 @@ enum struct MfmaInstr mfma_f32_16x16x8bf16, mfma_i32_32x32x8i8, mfma_i32_16x16x16i8, + mfma_f64_16x16x4f64 }; template @@ -383,12 +384,40 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t group_size = 1; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f64_16x16x4f64::Run(a, b, reg_c); + } +}; + template struct MfmaSelector { template static constexpr auto GetMfma(); + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f64_16x16x4f64; + } + template <> static constexpr auto GetMfma() { @@ -661,9 +690,10 @@ struct XdlopsGemm template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "base base_type must be float, half, bfloat16, and int8_t!"); + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "base base_type must be double, float, half, bfloat16, and int8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { mfma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 53c24b9a98..1e74120f11 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -6,6 +6,8 @@ namespace ck { template union BufferResource { + __device__ constexpr BufferResource() : content{} {} + // 128 bit SGPRs to supply buffer resource in buffer instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions int32x4_t content; @@ -258,6 +260,14 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); +// buffer atomic-add fp32 +__device__ double +llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, + int32x4_t rsrc, // dst_wave_buffer_resource + int voffset, // dst_thread_addr_offset + int soffset, // dst_wave_addr_offset + int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); + template __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, index_t src_thread_addr_offset, @@ -915,6 +925,71 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ } } +template +__device__ void amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(double), + 0); + } + } +} + // buffer_load requires: // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. @@ -1046,4 +1121,39 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr #endif } +// buffer_atomic_max requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + } // namespace ck diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 94693f510e..d978d7571a 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -294,5 +294,24 @@ struct intrin_mfma_i32_16x16x16i8<16, 16> } }; +template +struct intrin_mfma_f64_16x16x4f64; + +template <> +struct intrin_mfma_f64_16x16x4f64<16, 16> +{ + template + __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) + { +#ifdef __gfx90a__ + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; } // namespace ck #endif diff --git a/include/ck/utility/common_header.hpp b/include/ck/utility/common_header.hpp index 539263703b..34c0a7821b 100644 --- a/include/ck/utility/common_header.hpp +++ b/include/ck/utility/common_header.hpp @@ -32,7 +32,7 @@ #include "debug.hpp" #include "amd_buffer_addressing.hpp" -#include "generic_memory_space_atomic_add.hpp" +#include "generic_memory_space_atomic.hpp" #include "get_id.hpp" #include "synchronization.hpp" #include "amd_address_space.hpp" diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index bf8dc74f34..a723196539 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,4 +1,5 @@ #pragma once + #include "statically_indexed_array.hpp" namespace ck { @@ -1000,6 +1001,11 @@ struct NumericLimits __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } + + __host__ __device__ static constexpr T QuietNaN() + { + return std::numeric_limits::quiet_NaN(); + } }; template <> @@ -1008,12 +1014,15 @@ struct NumericLimits static constexpr unsigned short binary_min = 0x0400; static constexpr unsigned short binary_max = 0x7BFF; static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } }; } // namespace ck diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index c00982dfff..0ad78423fe 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -3,7 +3,7 @@ #include "enable_if.hpp" #include "c_style_pointer_cast.hpp" #include "amd_buffer_addressing.hpp" -#include "generic_memory_space_atomic_add.hpp" +#include "generic_memory_space_atomic.hpp" namespace ck { @@ -125,6 +125,10 @@ struct DynamicBuffer { this->template AtomicAdd(i, is_valid_element, x); } + else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax) + { + this->template AtomicMax(i, is_valid_element, x); + } else if constexpr(Op == InMemoryDataOperationEnum::Add) { auto tmp = this->template Get(i, is_valid_element); @@ -326,6 +330,42 @@ struct DynamicBuffer } } + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 + using scalar_t = typename scalar_type>::type; + bool constexpr use_amd_buffer_addressing = is_same_v, double>; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_max, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); + } + else if(is_valid_element) + { + atomic_max(c_style_pointer_cast(&p_data_[i]), x); + } + } + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } diff --git a/include/ck/utility/enable_if.hpp b/include/ck/utility/enable_if.hpp index 501e1bfc1c..db54f25aa0 100644 --- a/include/ck/utility/enable_if.hpp +++ b/include/ck/utility/enable_if.hpp @@ -1,5 +1,4 @@ -#ifndef CK_ENABLE_IF_HPP -#define CK_ENABLE_IF_HPP +#pragma once namespace ck { @@ -10,4 +9,3 @@ template using enable_if_t = typename std::enable_if::type; } // namespace ck -#endif diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp new file mode 100644 index 0000000000..1a2dacb5c5 --- /dev/null +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -0,0 +1,120 @@ +#pragma once +#include "data_type.hpp" + +namespace ck { + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_add explicit for +// each datatype. +template +__device__ X atomic_add(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_add(int32_t* p_dst, const int32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ uint32_t atomic_add(uint32_t* p_dst, const uint32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float atomic_add(float* p_dst, const float& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ double atomic_add(double* p_dst, const double& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float2_t atomic_add(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ double2_t atomic_add(double2_t* p_dst, const double2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_max explicit for +// each datatype. + +template +__device__ X atomic_max(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_max(int32_t* p_dst, const int32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ uint32_t atomic_max(uint32_t* p_dst, const uint32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float atomic_max(float* p_dst, const float& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ double atomic_max(double* p_dst, const double& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float2_t atomic_max(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicMax(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicMax(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +} // namespace ck diff --git a/include/ck/utility/generic_memory_space_atomic_add.hpp b/include/ck/utility/generic_memory_space_atomic_add.hpp deleted file mode 100644 index 8ee2081776..0000000000 --- a/include/ck/utility/generic_memory_space_atomic_add.hpp +++ /dev/null @@ -1,44 +0,0 @@ -#pragma once -#include "data_type.hpp" - -namespace ck { - -template -__device__ X atomic_add(X* p_dst, const X& x); - -template <> -__device__ int32_t atomic_add(int32_t* p_dst, const int32_t& x) -{ - return atomicAdd(p_dst, x); -} - -template <> -__device__ uint32_t atomic_add(uint32_t* p_dst, const uint32_t& x) -{ - return atomicAdd(p_dst, x); -} - -template <> -__device__ float atomic_add(float* p_dst, const float& x) -{ - return atomicAdd(p_dst, x); -} - -template <> -__device__ float2_t atomic_add(float2_t* p_dst, const float2_t& x) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - const vector_type vx{x}; - vector_type vy{0}; - - vy.template AsType()(I0) = - atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); - vy.template AsType()(I1) = - atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); - - return vy.template AsType()[I0]; -} - -} // namespace ck diff --git a/include/ck/utility/get_id.hpp b/include/ck/utility/get_id.hpp index 14938081e4..7c62b890c7 100644 --- a/include/ck/utility/get_id.hpp +++ b/include/ck/utility/get_id.hpp @@ -4,16 +4,21 @@ namespace ck { __host__ __device__ constexpr index_t get_warp_size() -{ // warpSize is defined by HIP +{ + // warpSize is defined by HIP return warpSize; } __device__ index_t get_thread_local_1d_id() { return threadIdx.x; } +__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; } + __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); } __device__ index_t get_block_1d_id() { return blockIdx.x; } __device__ index_t get_grid_size() { return gridDim.x; } +__device__ index_t get_block_size() { return blockDim.x; } + } // namespace ck diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index 3071e45640..59fe17e867 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -1,6 +1,4 @@ -#ifndef CK_INNER_PRODUCT_HPP -#define CK_INNER_PRODUCT_HPP - +#pragma once #include "data_type.hpp" namespace ck { @@ -138,7 +136,7 @@ template <> __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) { -#if defined(CK_USE_DOT4_I32_I8) +#if defined(CK_USE_AMD_V_DOT4_I32_I8) #if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM asm volatile("\n \ v_dot4_i32_i8 %0, %1, %2, %0\n \ @@ -202,4 +200,3 @@ inner_product(const int8x16_t& a, const int8x16_t } } // namespace ck -#endif diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index 48438e6179..e7724a40c8 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -142,6 +142,22 @@ __host__ __device__ constexpr auto min(X x, Ys... ys) return min(x, min(ys...)); } +// disallow implicit type casting +template +__device__ T exp(T x); + +template <> +__device__ float exp(float x) +{ + return __expf(x); +} + +template <> +__device__ double exp(double x) +{ + return exp(x); +} + // greatest common divisor, aka highest common factor __host__ __device__ constexpr index_t gcd(index_t x, index_t y) { diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 572d576e7a..438f5e12bd 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -3,11 +3,13 @@ #include #include "data_type.hpp" -#include "half.hpp" +#include "type.hpp" namespace ck { namespace math { +// math functions for the host, some are implemented by calling C++ std functions + static inline __host__ float abs(float x) { return std::abs(x); }; static inline __host__ double abs(double x) { return std::abs(x); }; @@ -28,26 +30,26 @@ static inline __host__ int32_t abs(int32_t x) static inline __host__ half_t abs(half_t x) { - half_float::half xx = *reinterpret_cast(&x); + uint16_t xx = ck::bit_cast(x); - half_float::half abs_xx = half_float::abs(xx); + uint16_t abs_xx = xx & 0x7fff; - half_t abs_x = *reinterpret_cast(&abs_xx); + half_t abs_x = ck::bit_cast(abs_xx); return abs_x; }; -static inline __host__ float isnan(float x) { return std::isnan(x); }; +static inline __host__ bool isnan(float x) { return std::isnan(x); }; -static inline __host__ double isnan(double x) { return std::isnan(x); }; +static inline __host__ bool isnan(double x) { return std::isnan(x); }; -static inline __host__ int8_t isnan(int8_t x) +static inline __host__ bool isnan(int8_t x) { (void)x; return false; }; -static inline __host__ int32_t isnan(int32_t x) +static inline __host__ bool isnan(int32_t x) { (void)x; return false; @@ -55,11 +57,59 @@ static inline __host__ int32_t isnan(int32_t x) static inline __host__ bool isnan(half_t x) { - half_float::half xx = *reinterpret_cast(&x); + uint16_t xx = ck::bit_cast(x); - return half_float::isnan(xx); + return (xx & 0x7FFF) > 0x7C00; }; +static inline __host__ float sqrt(float x) { return std::sqrt(x); }; + +static inline __host__ double sqrt(double x) { return std::sqrt(x); }; + +// math functions for the HIP kernel, some are implemented by calling hip builtin functions + +static inline __device__ float abs(float x) { return ::abs(x); }; + +static inline __device__ double abs(double x) { return ::abs(x); }; + +static inline __device__ int8_t abs(int8_t x) +{ + int8_t sgn = x >> (8 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __device__ int32_t abs(int32_t x) +{ + int32_t sgn = x >> (32 - 1); + + return (x ^ sgn) - sgn; +}; + +static inline __device__ half_t abs(half_t x) { return ::__habs(x); }; + +static inline __device__ bool isnan(float x) { return ::isnan(x); }; + +static inline __device__ bool isnan(double x) { return ::isnan(x); }; + +static inline __device__ bool isnan(int8_t x) +{ + (void)x; + return false; +}; + +static inline __device__ bool isnan(int32_t x) +{ + (void)x; + return false; +}; + +static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); }; + +static inline __device__ float sqrt(float x) { return ::sqrtf(x); }; + +static inline __device__ double sqrt(double x) { return ::sqrt(x); }; + } // namespace math } // namespace ck diff --git a/include/ck/utility/number.hpp b/include/ck/utility/number.hpp index 6f262a4d9f..97a71f8a41 100644 --- a/include/ck/utility/number.hpp +++ b/include/ck/utility/number.hpp @@ -8,5 +8,8 @@ namespace ck { template using Number = integral_constant; +template +using LongNumber = integral_constant; + } // namespace ck #endif diff --git a/include/ck/utility/reduction_functions_accumulate.hpp b/include/ck/utility/reduction_functions_accumulate.hpp index 4e8636e5b2..05ce9b16ce 100644 --- a/include/ck/utility/reduction_functions_accumulate.hpp +++ b/include/ck/utility/reduction_functions_accumulate.hpp @@ -27,6 +27,7 @@ #define CK_REDUCTION_FUNCTIONS_BINOP_HPP #include "data_type.hpp" +#include "math_v2.hpp" #include "reduction_common.hpp" #include "reduction_operator.hpp" @@ -34,37 +35,46 @@ namespace ck { namespace detail { -template -static inline __device__ bool is_nan(T x) +// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs) +template +struct AccumulateWithNanIgnore { - return (isnan(x)); -}; - -template <> -inline __device__ bool is_nan(half_t x) -{ - return (__hisnan(x)); + __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + { + if(!isnan(currVal)) + { + ReduceOperation{}(accuVal, currVal); + } + }; }; template struct AccumulateWithNanCheck; +// Does not check for NaN; does not guarantee NaNs be propagated to result +// e.g., given that max(a, b) = a > b ? a : b +// then max(NaN, 1) returns 1 +// max(1, NaN) returns NaN +// since any comparison involving NaNs returns false template struct AccumulateWithNanCheck { // cppcheck-suppress constParameter - __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) { ReduceOperation{}(accuVal, currVal); }; }; +// Check for NaN; guarantees NaNs be propagated to result template struct AccumulateWithNanCheck { - __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) + __host__ __device__ static inline void Calculate(AccDataType& accuVal, AccDataType currVal) { - if(is_nan(currVal)) + using ck::math::isnan; + + if(isnan(currVal)) { accuVal = currVal; } @@ -81,7 +91,7 @@ struct AccumulateWithIndexAndNanCheck; template struct AccumulateWithIndexAndNanCheck { - __device__ static inline void + __host__ __device__ static inline void // cppcheck-suppress constParameter Calculate(AccDataType& accuVal, AccDataType currVal, @@ -101,12 +111,14 @@ template { // The method is called when the ReduceOperation is indexable and the user asked for indices - __device__ static inline void Calculate(AccDataType& accuVal, - AccDataType currVal, - IndexDataType& accuIndex, - IndexDataType currIndex) + __host__ __device__ static inline void Calculate(AccDataType& accuVal, + AccDataType currVal, + IndexDataType& accuIndex, + IndexDataType currIndex) { - if(is_nan(currVal)) + using ck::math::isnan; + + if(isnan(currVal)) { accuVal = currVal; accuIndex = currIndex; diff --git a/include/ck/utility/reduction_operator.hpp b/include/ck/utility/reduction_operator.hpp index 5893f60547..eccdf932d7 100644 --- a/include/ck/utility/reduction_operator.hpp +++ b/include/ck/utility/reduction_operator.hpp @@ -26,7 +26,9 @@ #ifndef CK_REDUCTION_OPERATOR_HPP #define CK_REDUCTION_OPERATOR_HPP -#include "common_header.hpp" +#include "config.hpp" +#include "data_type.hpp" +#include "type.hpp" namespace ck { @@ -35,18 +37,16 @@ namespace reduce { // Every binary operator used in reduction is represented by a templated functor class. Each functor // class must provide at least // three members: -// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary +// 1) GetIdentityValue() -- the interface to return the "identity element" for the binary // operator, "identity element" is the unique // element in the algebraic space that doesn't affect the value of other elements // when operated against them, and the concept is similar to zero vector in // vector space // (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf). -// 2) indexable -- boolean value indicating whether indices of the operated elements could be -// recorded. Usually, Min/Max operator could -// need to record the indices of elements. For operator like Add/Mul, no need to -// record the indices. -// 3) operator() -- the first argument of the operator must be both an input & output, and the -// corresponding variable usually stores +// 2) IsCompatibleInMemoryDataOperation() -- return true if the reduction task corresponding to this +// operator can use the InMemoryDataOperation to finalize, or else it return false 3) operator() -- +// the first argument of the operator must be both an input & output, and the corresponding variable +// usually stores // the accumulated result of many operator() calls; the second argument is only an // input. For indexable binary // operator, the second version of operator() has third argument (which is an @@ -55,44 +55,92 @@ namespace reduce { // accumulated index also need be // changed. -template struct Add { - using dataType = T; + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; - __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::AtomicAdd || + operation == InMemoryDataOperationEnum::Set; + }; - __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Add accumulator!"); + + a = a + b; + } }; -template struct Mul { - using dataType = T; + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(1.0f); + }; - __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast(1.0f); }; + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + return operation == InMemoryDataOperationEnum::Set; + }; - __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } + template + __host__ __device__ inline constexpr void operator()(T& a, T b) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Mul accumulator!"); + + a = a * b; + } }; -template struct Max { - using dataType = T; - - __host__ __device__ static constexpr T GetReductionZeroVal() + template + __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits::Lowest(); }; + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_max to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + template __host__ __device__ inline constexpr void operator()(T& a, T b) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Max accumulator!"); + if(a < b) a = b; } + template __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Max accumulator!"); + if(a < b) { a = b; @@ -101,24 +149,41 @@ struct Max } }; -template struct Min { - using dataType = T; - - __host__ __device__ static constexpr T GetReductionZeroVal() + template + __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits::Max(); }; + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_min to be added + return operation == InMemoryDataOperationEnum::Set; + }; + + template __host__ __device__ inline constexpr void operator()(T& a, T b) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Min accumulator!"); + if(a > b) a = b; } + template __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the Min accumulator!"); + if(a > b) { a = b; @@ -127,21 +192,41 @@ struct Min } }; -template struct AMax { - using dataType = T; + template + __host__ __device__ static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; - __host__ __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; + __host__ __device__ static constexpr bool + IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) + { + // ToChange: atomic_max to be added + return operation == InMemoryDataOperationEnum::Set; + }; + template __host__ __device__ inline constexpr void operator()(T& a, T b) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the AMax accumulator!"); + if(a < b) a = b; } + template __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "The data type is not supported by the AMax accumulator!"); + if(a < b) { a = b; @@ -150,6 +235,55 @@ struct AMax } }; +template +constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation) +{ + T result = ck::type_convert(0.0f); + + if(operation == InMemoryDataOperationEnum::AtomicMax) + result = ck::NumericLimits::Lowest(); + + return (result); +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = false; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value; +}; + +template +struct InMemoryDataOperatonSupportedOnDataType +{ + static constexpr bool value = + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value; +}; + }; // end of namespace reduce } // end of namespace ck diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index c2adfc5063..da0fa50bf3 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -1,5 +1,4 @@ -#ifndef CK_SEQUENCE_HPP -#define CK_SEQUENCE_HPP +#pragma once #include "integral_constant.hpp" #include "type.hpp" @@ -241,7 +240,13 @@ struct arithmetic_sequence_gen } }; - using type = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; + using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; + using type1 = Sequence<>; + + static constexpr bool kHasContent = + (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd); + + using type = typename conditional::type; }; // uniform sequence @@ -882,5 +887,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f) return flag; } +template +using sequence_merge_t = typename sequence_merge::type; + +template +using uniform_sequence_gen_t = typename uniform_sequence_gen::type; + } // namespace ck -#endif diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index f36328fa5f..ef177e9697 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -36,6 +36,11 @@ struct StaticBuffer : public StaticallyIndexedArray { return base::operator()(i); } + + __host__ __device__ void Clear() + { + static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; }); + } }; // static buffer for vector @@ -146,9 +151,9 @@ struct StaticBufferTupleOfVector __host__ __device__ void Clear() { - const index_t numScalars = NumOfVector * ScalarPerVector; + constexpr index_t NumScalars = NumOfVector * ScalarPerVector; - static_for<0, Number{}, 1>{}([&](auto i) { SetAsType(i, S{0}); }); + static_for<0, NumScalars, 1>{}([&](auto i) { SetAsType(i, S{0}); }); } }; @@ -158,5 +163,11 @@ __host__ __device__ constexpr auto make_static_buffer(Number) return StaticBuffer{}; } +template +__host__ __device__ constexpr auto make_static_buffer(LongNumber) +{ + return StaticBuffer{}; +} + } // namespace ck #endif diff --git a/include/ck/utility/statically_indexed_array_multi_index.hpp b/include/ck/utility/statically_indexed_array_multi_index.hpp index 9e96f06d73..e0ee9d04fd 100644 --- a/include/ck/utility/statically_indexed_array_multi_index.hpp +++ b/include/ck/utility/statically_indexed_array_multi_index.hpp @@ -93,6 +93,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple& x) return r; } +// MultiIndex = MultiIndex * index_t +template +__host__ __device__ constexpr auto operator*(const Tuple& x, index_t a) +{ + return a * x; +} + template __host__ __device__ void print_multi_index(const Tuple& x) { diff --git a/include/ck/utility/tuple.hpp b/include/ck/utility/tuple.hpp index 9fa77d1932..f0cb440045 100644 --- a/include/ck/utility/tuple.hpp +++ b/include/ck/utility/tuple.hpp @@ -1,5 +1,4 @@ -#ifndef CK_TUPLE_HPP -#define CK_TUPLE_HPP +#pragma once #include "integral_constant.hpp" #include "sequence.hpp" @@ -17,14 +16,18 @@ struct TupleElementKey }; template -struct TupleElement +struct TupleElementKeyData { - __host__ __device__ constexpr TupleElement() = default; +#if 0 // workaround compiler complaint about implicitly-deleted default constructor + __host__ __device__ constexpr TupleElementKeyData() = default; +#else + __host__ __device__ constexpr TupleElementKeyData() : mData{} {} +#endif template , TupleElement>::value, + typename enable_if, TupleElementKeyData>::value, bool>::type = false> - __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward(v)) + __host__ __device__ constexpr TupleElementKeyData(T&& v) : mData(std::forward(v)) { } @@ -32,20 +35,21 @@ struct TupleElement }; template -__host__ __device__ constexpr const Data& get_tuple_element(const TupleElement& x) +__host__ __device__ constexpr const Data& +get_tuple_element_data(const TupleElementKeyData& x) { return static_cast(x.mData); } template -__host__ __device__ constexpr Data& get_tuple_element(TupleElement& x) +__host__ __device__ constexpr Data& get_tuple_element_data(TupleElementKeyData& x) { return x.mData; } // TODO: not sure the use of reference is correct template -__host__ __device__ constexpr Data&& get_tuple_element(TupleElement&& x) +__host__ __device__ constexpr Data&& get_tuple_element_data(TupleElementKeyData&& x) { return static_cast(x.mData); } @@ -54,7 +58,7 @@ template struct TupleImpl; template -struct TupleImpl, Xs...> : TupleElement, Xs>... +struct TupleImpl, Xs...> : TupleElementKeyData, Xs>... { __host__ __device__ constexpr TupleImpl() = default; @@ -63,13 +67,13 @@ struct TupleImpl, Xs...> : TupleElement, Xs> !is_same, TupleImpl>::value, bool>::type = false> __host__ __device__ constexpr TupleImpl(Y&& y) - : TupleElement, Xs>(std::forward(y))... + : TupleElementKeyData, Xs>(std::forward(y))... { } template = 2, bool>::type = false> __host__ __device__ constexpr TupleImpl(Ys&&... ys) - : TupleElement, Xs>(std::forward(ys))... + : TupleElementKeyData, Xs>(std::forward(ys))... { static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), "wrong! inconsistent size"); @@ -78,15 +82,15 @@ struct TupleImpl, Xs...> : TupleElement, Xs> __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } template - __host__ __device__ constexpr const auto& GetElementByKey(TupleElementKey) const + __host__ __device__ constexpr const auto& GetElementDataByKey(TupleElementKey) const { - return get_tuple_element>(*this); + return get_tuple_element_data>(*this); } template - __host__ __device__ constexpr auto& GetElementByKey(TupleElementKey) + __host__ __device__ constexpr auto& GetElementDataByKey(TupleElementKey) { - return get_tuple_element>(*this); + return get_tuple_element_data>(*this); } }; @@ -101,8 +105,7 @@ struct Tuple : detail::TupleImpl, Tuple>::value, + typename enable_if, Tuple>::value, bool>::type = false> __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) { @@ -122,7 +125,7 @@ struct Tuple : detail::TupleImpl) const { static_assert(I < base::Size(), "wrong! out of range"); - return base::GetElementByKey(detail::TupleElementKey{}); + return base::GetElementDataByKey(detail::TupleElementKey{}); } // write access @@ -130,7 +133,7 @@ struct Tuple : detail::TupleImpl) { static_assert(I < base::Size(), "wrong! out of range"); - return base::GetElementByKey(detail::TupleElementKey{}); + return base::GetElementDataByKey(detail::TupleElementKey{}); } // read access @@ -160,6 +163,31 @@ struct Tuple : detail::TupleImpl +struct Tuple<> +{ + __host__ __device__ constexpr Tuple() = default; + + __host__ __device__ static constexpr index_t Size() { return 0; } + + template + __host__ __device__ constexpr auto operator=(const T&) + { + return *this; + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } +}; + +template +struct tuple_element +{ + using type = decltype(TTuple{}.At(Number{})); +}; + +template +using tuple_element_t = typename tuple_element::type; + template __host__ __device__ constexpr auto make_tuple(Xs&&... xs) { @@ -174,4 +202,3 @@ constexpr Tuple tie(Args&... args) noexcept } } // namespace ck -#endif diff --git a/include/ck/utility/tuple_helper.hpp b/include/ck/utility/tuple_helper.hpp index 4e5b9cf97c..e7b17ca6a9 100644 --- a/include/ck/utility/tuple_helper.hpp +++ b/include/ck/utility/tuple_helper.hpp @@ -1,5 +1,4 @@ -#ifndef CK_TUPLE_HELPER_HPP -#define CK_TUPLE_HELPER_HPP +#pragma once #include "functional4.hpp" #include "tuple.hpp" @@ -20,6 +19,17 @@ __host__ __device__ constexpr auto generate_tie(F&& f, Number) typename arithmetic_sequence_gen<0, N, 1>::type{}); } +// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue) +template +__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple& tx, + const Tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return Tuple{std::forward(zs)...}; }, + tx, + ty); +} + namespace detail { template @@ -66,4 +76,3 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, } } // namespace ck -#endif diff --git a/include/ck/utility/type.hpp b/include/ck/utility/type.hpp index e212c82232..ee3189ebe5 100644 --- a/include/ck/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv::type; template using remove_cvref_t = remove_cv_t>; +template +using remove_pointer_t = typename std::remove_pointer::type; + template inline constexpr bool is_pointer_v = std::is_pointer::value; diff --git a/library/include/ck/library/host/host_interface.hpp b/library/include/ck/library/host/host_interface.hpp new file mode 100644 index 0000000000..955da0f4be --- /dev/null +++ b/library/include/ck/library/host/host_interface.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#include "stream_config.hpp" +#include "config.hpp" +#include "device_base.hpp" + +struct DeviceConvFwdPtr_t +{ + using BaseArgument = ck::tensor_operation::device::BaseArgument; + using BaseInvoker = ck::tensor_operation::device::BaseInvoker; + + struct DeviceConvFwdPtrImpl; + std::unique_ptr pImpl; + DeviceConvFwdPtr_t(); + ~DeviceConvFwdPtr_t(); + DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&); + DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&); + DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&) = delete; + DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete; + std::unique_ptr + MakeArgumentPointer(void* in_ptr, + void* wei_ptr, + void* out_ptr, + size_t N, + size_t K, + size_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + const; // in,wei and out element ops are ignored for now since even if we change them, they + // cant be linked + std::unique_ptr + MakeInvokerPointer() const; // requires including BaseInvoker headers + std::string GetTypeString(); + bool IsSupportedArgument(const BaseArgument* arg_ptr); +}; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t( + std::vector& instances); +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t( + std::vector& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t( + std::vector& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t( + std::vector& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t( + std::vector& instances); diff --git a/library/include/ck/library/host_tensor/device.hpp b/library/include/ck/library/host_tensor/device.hpp index f33b8d4f40..990d2f98b3 100644 --- a/library/include/ck/library/host_tensor/device.hpp +++ b/library/include/ck/library/host_tensor/device.hpp @@ -1,12 +1,34 @@ -#ifndef DEVICE_HPP -#define DEVICE_HPP +#pragma once #include #include #include #include -#include "hip/hip_runtime.h" -#include "hip/hip_fp16.h" +#include +#include + +#include "stream_config.hpp" +#include "ck/options.hpp" + +template +__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) +{ + for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x) + { + p[i] = x; + } +} + +inline void hip_check_error(hipError_t x) +{ + if(x != hipSuccess) + { + std::ostringstream ss; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__ + << "in function: " << __func__; + throw std::runtime_error(ss.str()); + } +} struct DeviceMem { @@ -17,6 +39,16 @@ struct DeviceMem void ToDevice(const void* p); void FromDevice(void* p); void SetZero(); + template + void SetValue(T x) + { + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } + + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); + } ~DeviceMem(); void* mpDeviceBuf; @@ -36,49 +68,56 @@ struct KernelTimer std::unique_ptr impl; }; -using device_stream_t = hipStream_t; - template -void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) +float launch_and_time_kernel(const StreamConfig& stream_config, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... args) { - hipStream_t stream_id = nullptr; - - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); -} - -template -float launch_and_time_kernel( - F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) -{ - KernelTimer timer; - - printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", - __func__, - grid_dim.x, - grid_dim.y, - grid_dim.z, - block_dim.x, - block_dim.y, - block_dim.z); - - printf("Warm up\n"); - - hipStream_t stream_id = nullptr; - - // warm up - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); - - printf("Start running %d times...\n", nrepeat); - - timer.Start(); - - for(int i = 0; i < nrepeat; ++i) +#if CK_TIME_KERNEL + if(stream_config.time_kernel_) { - hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + const int nrepeat = 10; + + printf("Warm up 1 time\n"); + + // warm up + kernel<<>>(args...); + + printf("Start running %d times...\n", nrepeat); + + KernelTimer timer; + timer.Start(); + + for(int i = 0; i < nrepeat; ++i) + { + kernel<<>>(args...); + } + + timer.End(); + + return timer.GetElapsedTime() / nrepeat; } + else + { + kernel<<>>(args...); - timer.End(); + return 0; + } +#else + kernel<<>>(args...); - return timer.GetElapsedTime() / nrepeat; -} + return 0; #endif +} diff --git a/library/include/ck/library/host_tensor/host_common_util.hpp b/library/include/ck/library/host_tensor/host_common_util.hpp new file mode 100644 index 0000000000..8fc1d36430 --- /dev/null +++ b/library/include/ck/library/host_tensor/host_common_util.hpp @@ -0,0 +1,102 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_HOST_COMMON_UTIL_HPP +#define GUARD_HOST_COMMON_UTIL_HPP + +#include +#include +#include +#include + +#include "config.hpp" + +namespace ck { + +namespace host_common { + +template +static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) +{ + std::ofstream outFile(fileName, std::ios::binary); + if(outFile) + { + outFile.write(reinterpret_cast(data), dataNumItems * sizeof(T)); + outFile.close(); + std::cout << "Write output to file " << fileName << std::endl; + } + else + { + std::cout << "Could not open file " << fileName << " for writing" << std::endl; + } +}; + +template +static inline T getSingleValueFromString(const std::string& valueStr) +{ + std::istringstream iss(valueStr); + + T val; + + iss >> val; + + return (val); +}; + +template +static inline std::vector getTypeValuesFromString(const char* cstr_values) +{ + std::string valuesStr(cstr_values); + + std::vector values; + std::size_t pos = 0; + std::size_t new_pos; + + new_pos = valuesStr.find(',', pos); + while(new_pos != std::string::npos) + { + const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); + + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + pos = new_pos + 1; + new_pos = valuesStr.find(',', pos); + }; + + std::string sliceStr = valuesStr.substr(pos); + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + return (values); +} + +}; // namespace host_common + +}; // namespace ck + +#endif diff --git a/library/include/ck/library/host_tensor/host_reduce_util.hpp b/library/include/ck/library/host_tensor/host_reduce_util.hpp deleted file mode 100644 index 53e17bcb5c..0000000000 --- a/library/include/ck/library/host_tensor/host_reduce_util.hpp +++ /dev/null @@ -1,269 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef GUARD_HOST_REDUCE_UTIL_HPP -#define GUARD_HOST_REDUCE_UTIL_HPP - -#include -#include -#include -#include -#include - -#include "reduction_enums.hpp" -#include "data_type.hpp" -#include "math_v2.hpp" - -namespace ck { - -namespace host_reduce { - -using ck::NanPropagation; -using ck::ReduceTensorOp; - -template -__host__ static inline std::function PreUnaryOpFn(int) -{ - using ck::math::abs; - - if constexpr(ReduceOpId == ReduceTensorOp::NORM1) - { - return ([&](AccDataType& a_) { a_ = abs(a_); }); - } - else if constexpr(ReduceOpId == ReduceTensorOp::NORM2) - { - return ([&](AccDataType& a_) { a_ = a_ * a_; }); - } - else if constexpr(ReduceOpId == ReduceTensorOp::AMAX) - { - return ([&](AccDataType& a_) { a_ = abs(a_); }); - } - else - { - // ReduceTensorOp::AVG: - // ReduceTensorOp::ADD: - // ReduceTensorOp::MUL: - // ReduceTensorOp::MIN: - // ReduceTensorOp::MAX: - return ([&](AccDataType&) {}); - }; -}; - -template -__host__ static inline std::function PosUnaryOpFn(int32_t divider) -{ - using std::sqrt; - - if constexpr(ReduceOpId == ReduceTensorOp::NORM2) - { - return ([&](AccDataType& a_) { a_ = sqrt(a_); }); - } - else if constexpr(ReduceOpId == ReduceTensorOp::AVG) - { - return ([&, divider](AccDataType& a_) { - a_ = a_ / static_cast(static_cast(divider)); - }); - } - else - { - // ReduceTensorOp::ADD: - // ReduceTensorOp::NORM1: - // ReduceTensorOp::MUL: - // ReduceTensorOp::MIN: - // ReduceTensorOp::MAX: - // ReduceTensorOp::AMAX: - return ([&](AccDataType&) {}); - } -}; - -template -__host__ static inline std::function ReduceOpFn() -{ - if constexpr(ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::AVG || - ReduceOpId == ReduceTensorOp::NORM1 || ReduceOpId == ReduceTensorOp::NORM2) - { - return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ + b_; }); - } - else if constexpr(ReduceOpId == ReduceTensorOp::MUL) - { - return ([&](AccDataType& a_, AccDataType b_) { a_ = a_ * b_; }); - } - else if constexpr(ReduceOpId == ReduceTensorOp::MIN) - { - return ([&](AccDataType& a_, AccDataType b_) { - if(a_ > b_) - a_ = b_; - }); - } - else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX) - { - return ([&](AccDataType& a_, AccDataType b_) { - if(a_ < b_) - a_ = b_; - }); - } -}; - -template -__host__ static inline std::function ReduceOpFn2() -{ - if constexpr(ReduceOpId == ReduceTensorOp::MIN) - { - return ([&](AccDataType& a_, AccDataType b_, bool& changed) { - if(a_ > b_) - { - a_ = b_; - changed = true; - } - else - changed = false; - }); - } - else if constexpr(ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX) - { - return ([&](AccDataType& a_, AccDataType b_, bool& changed) { - if(a_ < b_) - { - a_ = b_; - changed = true; - } - else - changed = false; - }); - } - else - { - // ReduceTensorOp::ADD: - // ReduceTensorOp::MUL: - // ReduceTensorOp::AVG: - // ReduceTensorOp::NORM1: - // ReduceTensorOp::NORM2: - return (std::function{}); - }; -}; - -template -__host__ static inline AccDataType ReduceOpZeroVal() -{ - if constexpr(ReduceOpId == ReduceTensorOp::MUL) - { - return (static_cast(1.0f)); - } - else if constexpr(ReduceOpId == ReduceTensorOp::MIN) - { - return (ck::NumericLimits::Max()); - } - else if constexpr(ReduceOpId == ReduceTensorOp::MAX) - { - return (ck::NumericLimits::Lowest()); - } - else if constexpr(ReduceOpId == ReduceTensorOp::AMAX) - { - return (static_cast(0.0f)); - } - else - { - // ReduceTensorOp::ADD - // ReduceTensorOp::AVG - // ReduceTensorOp::NORM1 - // ReduceTensorOp::NORM2 - return (static_cast(0.0f)); - }; -}; - -template -__host__ static inline void -binop_with_nan_check(std::function opReduce, - AccDataType& accuVal, - AccDataType currVal) -{ - using ck::math::isnan; - - if constexpr(!PropagateNan) - { - opReduce(accuVal, currVal); - } - else - { - if(isnan(currVal)) - accuVal = currVal; - else - opReduce(accuVal, currVal); - }; -}; - -template -__host__ static inline void -binop_with_nan_check2(std::function opReduce, - AccDataType& accuVal, - AccDataType currVal, - int& accuIndex, - int currIndex) -{ - using ck::math::isnan; - - if constexpr(!PropagateNan) - { - bool changed; - - opReduce(accuVal, currVal, changed); - - if(changed) - accuIndex = currIndex; - } - else - { - if(isnan(currVal)) - { - accuVal = currVal; - accuIndex = currIndex; - } - else - { - bool changed; - - opReduce(accuVal, currVal, changed); - - if(changed) - accuIndex = currIndex; - }; - }; -}; - -}; // namespace host_reduce - -static inline std::vector to_int_vector(const std::vector& inData) -{ - std::vector outData; - - for(auto elem : inData) - outData.push_back(static_cast(elem)); - - return (outData); -}; - -}; // namespace ck - -#endif diff --git a/library/include/ck/library/host_tensor/host_reduction.hpp b/library/include/ck/library/host_tensor/host_reduction.hpp index 786d34b73a..6c7162f067 100644 --- a/library/include/ck/library/host_tensor/host_reduction.hpp +++ b/library/include/ck/library/host_tensor/host_reduction.hpp @@ -33,9 +33,10 @@ #include "reduction_enums.hpp" #include "reduction_common.hpp" -#include "host_reduce_util.hpp" +#include "host_common_util.hpp" #include "host_tensor.hpp" #include "data_type.hpp" +#include "reduction_functions_accumulate.hpp" template static void get_all_indexes(const std::array& dimLengths, @@ -105,11 +106,13 @@ static size_t get_offset_from_index(const std::vector& strides, template + bool OutputIndex> struct ReductionHost { using IndexDataType = int32_t; @@ -121,8 +124,6 @@ struct ReductionHost std::vector reduceDims; IndexDataType divider; - std::function preUnaryOp; - std::function posUnaryOp; std::array reduceLengths; std::array reduceStrides; std::array invariantLengths; @@ -136,9 +137,6 @@ struct ReductionHost const std::vector& invariantDims_, const std::vector& reduceDims_) { - using ck::host_reduce::PosUnaryOpFn; - using ck::host_reduce::PreUnaryOpFn; - // this->outLengths = to_int_vector(outDesc.GetLengths()); this->outStrides = outDesc.GetStrides(); @@ -170,24 +168,24 @@ struct ReductionHost invariant_dim_indexes.clear(); get_all_indexes(invariantLengths, invariant_dim_indexes); }; - - preUnaryOp = PreUnaryOpFn(divider); - posUnaryOp = PosUnaryOpFn(divider); }; void Run(float alpha, const InDataType* in_data, float beta, OutDataType* out_data, - IndexDataType* out_indices) + IndexDataType* out_indices, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) { - if constexpr(NeedIndices) + if constexpr(OutputIndex) { - RunImpl_with_index(alpha, in_data, beta, out_data, out_indices); + RunImpl_with_index( + alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op); } else { - RunImpl_no_index(alpha, in_data, beta, out_data); + RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op); }; }; @@ -195,38 +193,39 @@ struct ReductionHost const InDataType* in_data, float beta, OutDataType* out_data, - IndexDataType* out_indices) + IndexDataType* out_indices, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) { using ck::float_equal_one; using ck::float_equal_zero; using ck::type_convert; - using ck::host_reduce::binop_with_nan_check2; - using ck::host_reduce::ReduceOpFn2; - using ck::host_reduce::ReduceOpZeroVal; - auto opReduce2 = ReduceOpFn2(); + using Accumulation = ck::detail::AccumulateWithIndexAndNanCheck; if constexpr(NumInvariantDim == 0) { - AccDataType accuVal = ReduceOpZeroVal(); + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); IndexDataType accuIndex = 0; - for(IndexDataType i = 0; i < reduce_dim_indexes.size(); i++) + for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) { auto offset_reduce = get_offset_from_index(reduceStrides, reduce_dim_indexes[i]); auto currVal = type_convert(in_data[offset_reduce]); - preUnaryOp(currVal); + in_elementwise_op(currVal, currVal); - auto currIndex = i; + auto currIndex = static_cast(i); - binop_with_nan_check2( - opReduce2, accuVal, currVal, accuIndex, currIndex); + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); }; - posUnaryOp(accuVal); + acc_elementwise_op(accuVal, accuVal); if(!float_equal_one{}(alpha)) accuVal *= type_convert(alpha); @@ -240,13 +239,13 @@ struct ReductionHost else { auto thread_reduce_func = [&](auto invariant_index) { - AccDataType accuVal = ReduceOpZeroVal(); + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); IndexDataType accuIndex = 0; auto offset_invariant = get_offset_from_index(invariantStrides, invariant_index); - for(IndexDataType i = 0; i < reduce_dim_indexes.size(); i++) + for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++) { auto offset_reduce = get_offset_from_index(reduceStrides, reduce_dim_indexes[i]); @@ -254,15 +253,14 @@ struct ReductionHost auto currVal = type_convert(in_data[offset_invariant + offset_reduce]); - preUnaryOp(currVal); + in_elementwise_op(currVal, currVal); - auto currIndex = i; + auto currIndex = static_cast(i); - binop_with_nan_check2( - opReduce2, accuVal, currVal, accuIndex, currIndex); + Accumulation::Calculate(accuVal, currVal, accuIndex, currIndex); }; - posUnaryOp(accuVal); + acc_elementwise_op(accuVal, accuVal); if(!float_equal_one{}(alpha)) accuVal *= type_convert(alpha); @@ -302,20 +300,23 @@ struct ReductionHost }; }; - void RunImpl_no_index(float alpha, const InDataType* in_data, float beta, OutDataType* out_data) + void RunImpl_no_index(float alpha, + const InDataType* in_data, + float beta, + OutDataType* out_data, + InElementwiseOperation in_elementwise_op, + AccElementwiseOperation acc_elementwise_op) { using ck::float_equal_one; using ck::float_equal_zero; using ck::type_convert; - using ck::host_reduce::binop_with_nan_check; - using ck::host_reduce::ReduceOpFn; - using ck::host_reduce::ReduceOpZeroVal; - auto opReduce = ReduceOpFn(); + using Accumulation = + ck::detail::AccumulateWithNanCheck; if constexpr(NumInvariantDim == 0) { - AccDataType accuVal = ReduceOpZeroVal(); + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); for(const auto& reduce_index : reduce_dim_indexes) { @@ -324,12 +325,12 @@ struct ReductionHost auto currVal = type_convert(in_data[offset_reduce]); - preUnaryOp(currVal); + in_elementwise_op(currVal, currVal); - binop_with_nan_check(opReduce, accuVal, currVal); + Accumulation::Calculate(accuVal, currVal); }; - posUnaryOp(accuVal); + acc_elementwise_op(accuVal, accuVal); if(!float_equal_one{}(alpha)) accuVal *= type_convert(alpha); @@ -342,7 +343,7 @@ struct ReductionHost else { auto thread_reduce_func = [&](auto invariant_index) { - AccDataType accuVal = ReduceOpZeroVal(); + AccDataType accuVal = ReduceOperation::template GetIdentityValue(); auto offset_invariant = get_offset_from_index(invariantStrides, invariant_index); @@ -355,12 +356,12 @@ struct ReductionHost auto currVal = type_convert(in_data[offset_invariant + offset_reduce]); - preUnaryOp(currVal); + in_elementwise_op(currVal, currVal); - binop_with_nan_check(opReduce, accuVal, currVal); + Accumulation::Calculate(accuVal, currVal); }; - posUnaryOp(accuVal); + acc_elementwise_op(accuVal, accuVal); if(!float_equal_one{}(alpha)) accuVal *= type_convert(alpha); diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index 0d4c9f73d4..6cbc15c2cd 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -107,6 +107,11 @@ struct HostTensorDescriptor return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); } + std::size_t GetOffsetFromMultiIndex(std::vector iss) const + { + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); private: @@ -154,7 +159,7 @@ struct ParallelTensorFunctor { std::array indices; - for(int idim = 0; idim < NDIM; ++idim) + for(std::size_t idim = 0; idim < NDIM; ++idim) { indices[idim] = i / mStrides[idim]; i -= indices[idim] * mStrides[idim]; @@ -212,6 +217,54 @@ struct Tensor Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {} + Tensor(const Tensor& other) : mDesc(other.mDesc), mData(other.mData) {} + + template + void ForEach_impl(F&& f, std::vector& idx, size_t rank) + { + if(rank == mDesc.GetNumOfDimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(F&& f) + { + std::vector idx(mDesc.GetNumOfDimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + + template + void ForEach_impl(const F&& f, std::vector& idx, size_t rank) const + { + if(rank == mDesc.GetNumOfDimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(const F&& f) const + { + std::vector idx(mDesc.GetNumOfDimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + template void GenerateTensorValue(G g, std::size_t num_thread = 1) { @@ -272,6 +325,16 @@ struct Tensor return mData[mDesc.GetOffsetFromMultiIndex(is...)]; } + T& operator()(std::vector idx) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } + + const T& operator()(std::vector idx) const + { + return mData[mDesc.GetOffsetFromMultiIndex(idx)]; + } + typename std::vector::iterator begin() { return mData.begin(); } typename std::vector::iterator end() { return mData.end(); } @@ -285,7 +348,8 @@ struct Tensor }; template -HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens) : mLens(lens) +HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens) + : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); } @@ -293,7 +357,7 @@ HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens) : mLens(l template HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens, const std::vector& strides) - : mLens(lens), mStrides(strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { } @@ -316,7 +380,7 @@ float check_error(const Tensor& ref, const Tensor& result) constexpr float eps = 1e-10; - for(int i = 0; i < ref.mData.size(); ++i) + for(std::size_t i = 0; i < ref.mData.size(); ++i) { float ref_v = ck::type_convert(ref.mData[i]); float result_v = ck::type_convert(result.mData[i]); diff --git a/library/include/ck/library/host_tensor/host_tensor_generator.hpp b/library/include/ck/library/host_tensor/host_tensor_generator.hpp index 17e20351f0..2813d6a9ae 100644 --- a/library/include/ck/library/host_tensor/host_tensor_generator.hpp +++ b/library/include/ck/library/host_tensor/host_tensor_generator.hpp @@ -18,12 +18,12 @@ struct GeneratorTensor_0 template struct GeneratorTensor_1 { - int value = 1; + T value = 1; template T operator()(Is...) { - return ck::type_convert(value); + return value; } }; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp index 3a706dac0b..f4944a28d2 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -84,7 +84,8 @@ struct ReferenceBatchedGemm : public device::BaseOperator return 0; } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp new file mode 100644 index 0000000000..c6a5304766 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_cgemm.hpp @@ -0,0 +1,203 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// FIXME: support arbitrary elementwise operation for A/B/C +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct ReferenceCGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k_real, + const Tensor& a_m_k_imag, + const Tensor& b_k_n_real, + const Tensor& b_k_n_imag, + Tensor& c_m_n_real, + Tensor& c_m_n_imag, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_real_{a_m_k_real}, + a_m_k_imag_{a_m_k_imag}, + b_k_n_real_{b_k_n_real}, + b_k_n_imag_{b_k_n_imag}, + c_m_n_real_{c_m_n_real}, + c_m_n_imag_{c_m_n_imag}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_real_; + const Tensor& a_m_k_imag_; + const Tensor& b_k_n_real_; + const Tensor& b_k_n_imag_; + Tensor& c_m_n_real_; + Tensor& c_m_n_imag_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceCGemm::Argument; + + float Run(const Argument& arg) + { + const std::size_t K = arg.a_m_k_real_.mDesc.GetLengths()[1]; + + if(K != arg.a_m_k_imag_.mDesc.GetLengths()[1]) + { + throw std::runtime_error("wrong! Incompatible real and imag sizes in CGEMM"); + } + + auto f_mk_kn_mn_real = [&](auto m, auto n) { + float v_c_real = 0; + + for(std::size_t k = 0; k < K; ++k) + { + float v_a_real = ck::type_convert(arg.a_m_k_real_(m, k)); + float v_a_imag = ck::type_convert(arg.a_m_k_imag_(m, k)); + float v_b_real = ck::type_convert(arg.b_k_n_real_(k, n)); + float v_b_imag = ck::type_convert(arg.b_k_n_imag_(k, n)); + + v_c_real += v_a_real * v_b_real - v_a_imag * v_b_imag; + } + + arg.c_m_n_real_(m, n) = v_c_real; + }; + + auto f_mk_kn_mn_imag = [&](auto m, auto n) { + float v_c_imag = 0; + + for(std::size_t k = 0; k < K; ++k) + { + float v_a_real = ck::type_convert(arg.a_m_k_real_(m, k)); + float v_a_imag = ck::type_convert(arg.a_m_k_imag_(m, k)); + float v_b_real = ck::type_convert(arg.b_k_n_real_(k, n)); + float v_b_imag = ck::type_convert(arg.b_k_n_imag_(k, n)); + + v_c_imag += v_a_real * v_b_imag + v_a_imag * v_b_real; + } + + arg.c_m_n_imag_(m, n) = v_c_imag; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn_real, + arg.c_m_n_real_.mDesc.GetLengths()[0], + arg.c_m_n_real_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_mk_kn_mn_imag, + arg.c_m_n_imag_.mDesc.GetLengths()[0], + arg.c_m_n_imag_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k_real, + const Tensor& a_m_k_imag, + const Tensor& b_k_n_real, + const Tensor& b_k_n_imag, + Tensor& c_m_n_real, + Tensor& c_m_n_imag, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k_real, + a_m_k_imag, + b_k_n_real, + b_k_n_imag, + c_m_n_real, + c_m_n_imag, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceCGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp index 70f9e3617e..4203085dbc 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp @@ -1,5 +1,4 @@ -#ifndef REFERENCE_CONV_WRW_HPP -#define REFERENCE_CONV_WRW_HPP +#pragma once #include #include @@ -16,7 +15,9 @@ template + typename OutElementwiseOperation, + ck::index_t NumDimSpatial = 2, + typename ck::enable_if= 1 && NumDimSpatial <= 3, bool>::type = false> struct ReferenceConvBwdWeight : public device::BaseOperator { // Argument @@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) - : in_n_c_hi_wi_{in_n_c_hi_wi}, - wei_k_c_y_x_{wei_k_c_y_x}, - out_n_k_ho_wo_{out_n_k_ho_wo}, + : input_{in_n_c_hi_wi}, + weight_{wei_k_c_y_x}, + output_{out_n_k_ho_wo}, conv_strides_{conv_filter_strides}, conv_dilations_{conv_filter_dilations}, in_left_pads_{input_left_pads}, @@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator { } - const Tensor& in_n_c_hi_wi_; - Tensor& wei_k_c_y_x_; - const Tensor& out_n_k_ho_wo_; + const Tensor& input_; + Tensor& weight_; + const Tensor& output_; std::vector conv_strides_; std::vector conv_dilations_; @@ -66,55 +67,184 @@ struct ReferenceConvBwdWeight : public device::BaseOperator float Run(const Argument& arg) { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - auto f_kcyx = [&](auto k, auto c, auto y, auto x) { - float v_acc = 0; - for(int n = 0; n < arg.out_n_k_ho_wo_.mDesc.GetLengths()[0]; ++n) - { - for(int ho = 0; ho < arg.out_n_k_ho_wo_.mDesc.GetLengths()[2]; ++ho) + if constexpr(NumDimSpatial == 1) + { + constexpr auto I0 = Number<0>{}; + auto f_kcx = [&](auto k, auto c, auto x) { + float v_acc = 0; + for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n) { - int hi = ho * arg.conv_strides_[I0] + y * arg.conv_dilations_[I0] - - arg.in_left_pads_[I0]; - for(int wo = 0; wo < arg.out_n_k_ho_wo_.mDesc.GetLengths()[3]; ++wo) + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[2]; ++wo) { - int wi = wo * arg.conv_strides_[I1] + x * arg.conv_dilations_[I1] - - arg.in_left_pads_[I1]; - if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && - wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + auto wi = + ck::type_convert(wo * arg.conv_strides_[I0]) + + ck::type_convert(x * arg.conv_dilations_[I0]) - + ck::type_convert(arg.in_left_pads_[I0]); + if(wi >= 0 && + ck::type_convert(wi) < arg.input_.mDesc.GetLengths()[2]) { float v_out; float v_in; - arg.out_element_op_( - v_out, - ck::type_convert(arg.out_n_k_ho_wo_(n, k, ho, wo))); - arg.in_element_op_( - v_in, ck::type_convert(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.out_element_op_(v_out, + ck::type_convert(arg.output_(n, k, wo))); + arg.in_element_op_(v_in, + ck::type_convert(arg.input_(n, c, wi))); v_acc += v_out * v_in; } } } - } - float v_wei; + float v_wei; - arg.wei_element_op_(v_wei, v_acc); + arg.wei_element_op_(v_wei, v_acc); - arg.wei_k_c_y_x_(k, c, y, x) = ck::type_convert(v_wei); - }; + arg.weight_(k, c, x) = ck::type_convert(v_wei); + }; - make_ParallelTensorFunctor(f_kcyx, - arg.wei_k_c_y_x_.mDesc.GetLengths()[0], - arg.wei_k_c_y_x_.mDesc.GetLengths()[1], - arg.wei_k_c_y_x_.mDesc.GetLengths()[2], - arg.wei_k_c_y_x_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_kcx, + arg.weight_.mDesc.GetLengths()[0], + arg.weight_.mDesc.GetLengths()[1], + arg.weight_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); - return 0; + return 0; + } + else if constexpr(NumDimSpatial == 2) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + auto f_kcyx = [&](auto k, auto c, auto y, auto x) { + float v_acc = 0; + for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n) + { + for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[2]; ++ho) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[I0]) + + ck::type_convert(y * arg.conv_dilations_[I0]) - + ck::type_convert(arg.in_left_pads_[I0]); + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[3]; ++wo) + { + auto wi = + ck::type_convert(wo * arg.conv_strides_[I1]) + + ck::type_convert(x * + arg.conv_dilations_[I1]) - + ck::type_convert(arg.in_left_pads_[I1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.mDesc.GetLengths()[3]) + { + float v_out; + float v_in; + + arg.out_element_op_( + v_out, ck::type_convert(arg.output_(n, k, ho, wo))); + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(n, c, hi, wi))); + + v_acc += v_out * v_in; + } + } + } + } + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(k, c, y, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kcyx, + arg.weight_.mDesc.GetLengths()[0], + arg.weight_.mDesc.GetLengths()[1], + arg.weight_.mDesc.GetLengths()[2], + arg.weight_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + else if constexpr(NumDimSpatial == 3) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + auto f_kczyx = [&](auto k, auto c, auto z, auto y, auto x) { + float v_acc = 0; + for(std::size_t n = 0; n < arg.output_.mDesc.GetLengths()[0]; ++n) + { + for(std::size_t do_ = 0; do_ < arg.output_.mDesc.GetLengths()[2]; ++do_) + { + auto di = + ck::type_convert(do_ * arg.conv_strides_[I0]) + + ck::type_convert(z * arg.conv_dilations_[I0]) - + ck::type_convert(arg.in_left_pads_[I0]); + for(std::size_t ho = 0; ho < arg.output_.mDesc.GetLengths()[3]; ++ho) + { + auto hi = + ck::type_convert(ho * arg.conv_strides_[I1]) + + ck::type_convert(y * + arg.conv_dilations_[I1]) - + ck::type_convert(arg.in_left_pads_[I1]); + for(std::size_t wo = 0; wo < arg.output_.mDesc.GetLengths()[4]; + ++wo) + { + auto wi = + ck::type_convert(wo * + arg.conv_strides_[I2]) + + ck::type_convert( + x * arg.conv_dilations_[I2]) - + ck::type_convert(arg.in_left_pads_[I2]); + if(di >= 0 && + ck::type_convert(di) < + arg.input_.mDesc.GetLengths()[2] && + hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[3] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.mDesc.GetLengths()[4]) + { + float v_out; + float v_in; + + arg.out_element_op_(v_out, + ck::type_convert( + arg.output_(n, k, do_, ho, wo))); + arg.in_element_op_( + v_in, + ck::type_convert(arg.input_(n, c, di, hi, wi))); + + v_acc += v_out * v_in; + } + } + } + } + } + float v_wei; + + arg.wei_element_op_(v_wei, v_acc); + + arg.weight_(k, c, z, y, x) = ck::type_convert(v_wei); + }; + + make_ParallelTensorFunctor(f_kczyx, + arg.weight_.mDesc.GetLengths()[0], + arg.weight_.mDesc.GetLengths()[1], + arg.weight_.mDesc.GetLengths()[2], + arg.weight_.mDesc.GetLengths()[3], + arg.weight_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } @@ -174,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index 0f210a23e1..11252e2398 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -78,15 +78,18 @@ struct ReferenceConvBwdData : public device::BaseOperator AccDataType v_acc = 0; - for(int x = 0; x < X; ++x) + for(std::size_t x = 0; x < X; ++x) { - int w_tmp = wi + arg.in_left_pads_[0] - x * arg.conv_dilations_[0]; + auto w_tmp = ck::type_convert(wi) + + ck::type_convert(arg.in_left_pads_[0]) - + ck::type_convert(x * arg.conv_dilations_[0]); if(w_tmp % arg.conv_strides_[0] == 0) { - int wo = w_tmp / arg.conv_strides_[0]; - if(wo >= 0 && wo < Wo) + auto wo = ck::type_convert(w_tmp) / + ck::type_convert(arg.conv_strides_[0]); + if(wo >= 0 && ck::type_convert(wo) < Wo) { - for(int k = 0; k < K; ++k) + for(std::size_t k = 0; k < K; ++k) { AccDataType v_out = 0; AccDataType v_wei = 0; @@ -103,9 +106,8 @@ struct ReferenceConvBwdData : public device::BaseOperator } } - float v_in; - arg.in_element_op_(v_in, v_acc); - arg.input_(n, c, wi) = ck::type_convert(v_in); + arg.in_element_op_(v_acc, v_acc); + arg.input_(n, c, wi) = ck::type_convert(v_acc); }; make_ParallelTensorFunctor(f_ncw, @@ -128,24 +130,32 @@ struct ReferenceConvBwdData : public device::BaseOperator AccDataType v_acc = 0; - for(int y = 0; y < Y; ++y) + for(std::size_t y = 0; y < Y; ++y) { - int h_tmp = hi + arg.in_left_pads_[0] - y * arg.conv_dilations_[0]; + auto h_tmp = ck::type_convert(hi) + + ck::type_convert(arg.in_left_pads_[0]) - + ck::type_convert(y * arg.conv_dilations_[0]); if(h_tmp % arg.conv_strides_[0] == 0) { - int ho = h_tmp / arg.conv_strides_[0]; - if(ho >= 0 && ho < Ho) + auto ho = ck::type_convert(h_tmp) / + ck::type_convert(arg.conv_strides_[0]); + if(ho >= 0 && ck::type_convert(ho) < Ho) { - for(int x = 0; x < X; ++x) + for(std::size_t x = 0; x < X; ++x) { - int w_tmp = - wi + arg.in_left_pads_[1] - x * arg.conv_dilations_[1]; + auto w_tmp = + ck::type_convert(wi) + + ck::type_convert(arg.in_left_pads_[1]) - + ck::type_convert(x * + arg.conv_dilations_[1]); if(w_tmp % arg.conv_strides_[1] == 0) { - int wo = w_tmp / arg.conv_strides_[1]; - if(wo >= 0 && wo < Wo) + auto wo = ck::type_convert(w_tmp) / + ck::type_convert( + arg.conv_strides_[1]); + if(wo >= 0 && ck::type_convert(wo) < Wo) { - for(int k = 0; k < K; ++k) + for(std::size_t k = 0; k < K; ++k) { AccDataType v_out = 0; AccDataType v_wei = 0; @@ -194,33 +204,49 @@ struct ReferenceConvBwdData : public device::BaseOperator AccDataType v_acc = 0; - for(int z = 0; z < Z; ++z) + for(std::size_t z = 0; z < Z; ++z) { - int d_tmp = di + arg.in_left_pads_[0] - z * arg.conv_dilations_[0]; + auto d_tmp = ck::type_convert(di) + + ck::type_convert(arg.in_left_pads_[0]) - + ck::type_convert(z * arg.conv_dilations_[0]); if(d_tmp % arg.conv_strides_[0] == 0) { - int do_ = d_tmp / arg.conv_strides_[0]; - if(do_ >= 0 && do_ < Do) + auto do_ = ck::type_convert(d_tmp) / + ck::type_convert(arg.conv_strides_[0]); + if(do_ >= 0 && ck::type_convert(do_) < Do) { - for(int y = 0; y < Y; ++y) + for(std::size_t y = 0; y < Y; ++y) { - int h_tmp = - hi + arg.in_left_pads_[1] - y * arg.conv_dilations_[1]; + auto h_tmp = + ck::type_convert(hi) + + ck::type_convert(arg.in_left_pads_[1]) - + ck::type_convert(y * + arg.conv_dilations_[1]); if(h_tmp % arg.conv_strides_[1] == 0) { - int ho = h_tmp / arg.conv_strides_[1]; - if(ho >= 0 && ho < Ho) + auto ho = ck::type_convert(h_tmp) / + ck::type_convert( + arg.conv_strides_[1]); + if(ho >= 0 && ck::type_convert(ho) < Ho) { - for(int x = 0; x < X; ++x) + for(std::size_t x = 0; x < X; ++x) { - int w_tmp = wi + arg.in_left_pads_[2] - - x * arg.conv_dilations_[2]; + auto w_tmp = + ck::type_convert(wi) + + ck::type_convert( + arg.in_left_pads_[2]) - + ck::type_convert( + x * arg.conv_dilations_[2]); if(w_tmp % arg.conv_strides_[2] == 0) { - int wo = w_tmp / arg.conv_strides_[2]; - if(wo >= 0 && wo < Wo) + auto wo = + ck::type_convert(w_tmp) / + ck::type_convert( + arg.conv_strides_[2]); + if(wo >= 0 && + ck::type_convert(wo) < Wo) { - for(int k = 0; k < K; ++k) + for(std::size_t k = 0; k < K; ++k) { AccDataType v_out = 0; AccDataType v_wei = 0; @@ -264,7 +290,8 @@ struct ReferenceConvBwdData : public device::BaseOperator } } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index 0095d51a5b..d1afa898e4 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -1,9 +1,10 @@ -#ifndef REFERENCE_CONV_FWD_HPP -#define REFERENCE_CONV_FWD_HPP +#pragma once #include #include #include + +#include "stream_config.hpp" #include "device_base.hpp" #include "host_tensor.hpp" @@ -88,13 +89,16 @@ struct ReferenceConvFwd : public device::BaseOperator auto f_ncw = [&](auto n, auto k, auto wo) { float v_acc = 0; - for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) + for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) { - for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) + for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) { - int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] - - arg.in_left_pads_[0]; - if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2]) + auto wi = + ck::type_convert(wo * arg.conv_strides_[0]) + + ck::type_convert(x * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + if(wi >= 0 && + ck::type_convert(wi) < arg.input_.mDesc.GetLengths()[2]) { float v_in; float v_wei; @@ -128,18 +132,26 @@ struct ReferenceConvFwd : public device::BaseOperator auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { float v_acc = 0; - for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) + for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) { - for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) + for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) { - int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - - arg.in_left_pads_[0]; - for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) + auto hi = + ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) { - int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - - arg.in_left_pads_[1]; - if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 && - wi < arg.input_.mDesc.GetLengths()[3]) + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.mDesc.GetLengths()[3]) { float v_in; float v_wei; @@ -174,23 +186,37 @@ struct ReferenceConvFwd : public device::BaseOperator auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { float v_acc = 0; - for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) + for(std::size_t c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) { - for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z) + for(std::size_t z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z) { - int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] - - arg.in_left_pads_[0]; - for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y) + auto di = + ck::type_convert(d_o * arg.conv_strides_[0]) + + ck::type_convert(z * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y) { - int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] - - arg.in_left_pads_[1]; - for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x) + auto hi = + ck::type_convert(ho * arg.conv_strides_[1]) + + ck::type_convert(y * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + for(std::size_t x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x) { - int wi = wo * arg.conv_strides_[2] + - x * arg.conv_dilations_[2] - arg.in_left_pads_[2]; - if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] && - hi >= 0 && hi < arg.input_.mDesc.GetLengths()[3] && - wi >= 0 && wi < arg.input_.mDesc.GetLengths()[4]) + auto wi = + ck::type_convert(wo * + arg.conv_strides_[2]) + + ck::type_convert(x * + arg.conv_dilations_[2]) - + ck::type_convert(arg.in_left_pads_[2]); + if(di >= 0 && + ck::type_convert(di) < + arg.input_.mDesc.GetLengths()[2] && + hi >= 0 && + ck::type_convert(hi) < + arg.input_.mDesc.GetLengths()[3] && + wi >= 0 && + ck::type_convert(wi) < + arg.input_.mDesc.GetLengths()[4]) { float v_in; float v_wei; @@ -226,7 +252,8 @@ struct ReferenceConvFwd : public device::BaseOperator } } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } @@ -286,4 +313,3 @@ struct ReferenceConvFwd : public device::BaseOperator } // namespace host } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp index 8f49b79a1a..4be6169c15 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation.hpp @@ -73,18 +73,25 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { float v_acc = 0; - for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) { - for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) { - int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - - arg.in_left_pads_[0]; - for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + auto hi = ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) { - int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - - arg.in_left_pads_[1]; - if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && - wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) { float v_in; float v_wei; @@ -117,7 +124,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator return 0; } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp index e4e0899416..466537c686 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd_bias_activation_add.hpp @@ -76,18 +76,25 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { float v_acc = 0; - for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + for(std::size_t c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) { - for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + for(std::size_t y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) { - int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - - arg.in_left_pads_[0]; - for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + auto hi = ck::type_convert(ho * arg.conv_strides_[0]) + + ck::type_convert(y * arg.conv_dilations_[0]) - + ck::type_convert(arg.in_left_pads_[0]); + for(std::size_t x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) { - int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - - arg.in_left_pads_[1]; - if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && - wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + auto wi = + ck::type_convert(wo * arg.conv_strides_[1]) + + ck::type_convert(x * arg.conv_dilations_[1]) - + ck::type_convert(arg.in_left_pads_[1]); + if(hi >= 0 && + ck::type_convert(hi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && + wi >= 0 && + ck::type_convert(wi) < + arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) { float v_in; float v_wei; @@ -123,7 +130,8 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator return 0; } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /*stream_config*/ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 1b49ca5740..6f097c6deb 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -11,6 +11,7 @@ namespace host { template @@ -53,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator auto f_mk_kn_mn = [&](auto m, auto n) { const int K = arg.a_m_k_.mDesc.GetLengths()[1]; - float v_acc = 0; + AccDataType v_acc = 0; for(int k = 0; k < K; ++k) { - float v_a; - float v_b; + AccDataType v_a; + AccDataType v_b; - arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); - arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); v_acc += v_a * v_b; } - float v_c; + AccDataType v_c; arg.c_element_op_(v_c, v_acc); @@ -80,7 +81,8 @@ struct ReferenceGemm : public device::BaseOperator return 0; } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp index 7dd6fc9199..a0ceb28a11 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp @@ -66,8 +66,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator for(int k = 0; k < K; ++k) { - arg.a_element_op_(a, arg.a_m_k_(m, k)); - arg.b_element_op_(b, arg.b_k_n_(k, n)); + arg.a_element_op_(a, ck::type_convert(arg.a_m_k_(m, k))); + arg.b_element_op_(b, ck::type_convert(arg.b_k_n_(k, n))); acc += a * b; } @@ -82,7 +82,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator return 0; } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp index 7c9df272c2..60f72e9e51 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation.hpp @@ -85,7 +85,8 @@ struct ReferenceGemmBiasActivation : public device::BaseOperator return 0; } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp index 4d3c5effae..5e0ec75e5e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_bias_activation_add.hpp @@ -91,7 +91,8 @@ struct ReferenceGemmBiasActivationAdd : public device::BaseOperator return 0; } - float Run(const device::BaseArgument* p_arg, int) override + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override { return Run(*dynamic_cast(p_arg)); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp new file mode 100644 index 0000000000..7271103d54 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp @@ -0,0 +1,162 @@ +#pragma once +#include +#include +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceSoftmax : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& in, + Tensor& out, + AccDataType alpha, + AccDataType beta, + const index_t rank, + const std::vector sm_reduce_dims) + : in_(in), out_(out), alpha_(alpha), beta_(beta), sm_reduce_dims_(sm_reduce_dims) + { + // std::cout << "debug: scalar dims: "; + for(int i = 0; i < rank; i++) + { + if(std::find(sm_reduce_dims.begin(), sm_reduce_dims.end(), i) == + sm_reduce_dims.end()) + { + sm_scalar_dims_.push_back(i); + // std::cout << i << ", "; + } + } + // std::cout << std::endl; + } + + const Tensor& in_; + Tensor& out_; + AccDataType alpha_; + AccDataType beta_; + index_t rank_; + std::vector sm_reduce_dims_; + std::vector sm_scalar_dims_; // dim after internal max/sum reduction + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + float Run(const Argument& arg) + { + std::vector scalar_lengths; + for(index_t dim : arg.sm_scalar_dims_) + { + scalar_lengths.push_back(arg.in_.mDesc.GetLengths()[dim]); + } + + Tensor reduce_max(scalar_lengths); + reduce_max.GenerateTensorValue( + GeneratorTensor_1{std::numeric_limits::lowest()}); + Tensor reduce_sum(scalar_lengths); + reduce_sum.GenerateTensorValue(GeneratorTensor_1{0}); + + auto to_sm_scalar_idx = [&](auto idx) { + std::vector sm_scalar_idx; + for(index_t dim : arg.sm_scalar_dims_) + { + sm_scalar_idx.push_back(idx[dim]); + } + return sm_scalar_idx; + }; + + arg.in_.ForEach([&](auto& self, auto idx) { + reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)), + static_cast(self(idx))); + }); + + // LogRangeAsType(std::cout << "reduce_max: ", reduce_max.mData, ",") << + // std::endl; + + Tensor in_stable(arg.in_.mDesc); + in_stable.ForEach([&](auto& self, auto idx) { + // numerator = exp(x - max(x)) + self(idx) = std::exp(static_cast(arg.in_(idx)) - + reduce_max(to_sm_scalar_idx(idx))); + }); + + // LogRangeAsType(std::cout << "in_stable: ", in_stable.mData, ",") << std::endl; + + in_stable.ForEach([&](auto& self, auto idx) { + // denominator = sum(exp(x - max(x))) + reduce_sum(to_sm_scalar_idx(idx)) += self(idx); + }); + + // LogRangeAsType(std::cout << "reduce_sum: ", reduce_sum.mData, ",") << + // std::endl; + + arg.out_.ForEach([&](auto& self, auto idx) { + self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) + + arg.beta_ * self(idx); + }); + + // LogRangeAsType(std::cout << "out: ", arg.out_.mData, ",") << std::endl; + // reduction along reduce dims + // LogRangeAsType(std::cout << "reduce_max: ", reduce_max.mData, ",") << + // std::endl; LogRangeAsType(std::cout << "reduce_sum: ", reduce_sum.mData, ",") + // << std::endl; + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& in, + Tensor& out, + AccDataType alpha, + AccDataType beta, + const index_t rank, + const std::vector sm_reduce_dims) + { + return Argument{in, out, alpha, beta, rank, sm_reduce_dims}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceSoftmax" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp index 40fd7274ef..13b6166107 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance.hpp @@ -1,7 +1,6 @@ -#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP -#define CK_DEVICE_OPERATION_INSTANCE_HPP +#pragma once -#include +#include namespace ck { namespace tensor_operation { @@ -23,4 +22,3 @@ void add_device_operation_instances(std::vector>& op } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp index fafbe120b9..6f0dbe75ff 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp @@ -9,26 +9,11 @@ #include "device_reduce_instance_blockwise_i8_i8_i8.hpp" #include "device_reduce_instance_blockwise_i8_i32_i8.hpp" #include "device_reduce_instance_blockwise_b16_f32_b16.hpp" -#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp" -#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp" -#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp" -#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp" -#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp" -#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp" -#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp" -#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp" #include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp" +#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp" #include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp" -#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp" -#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp" -#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp" -#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp" -#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp" -#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp" -#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp" -#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp" #include "device_reduce_instance_threadwise_f16_f16_f16.hpp" #include "device_reduce_instance_threadwise_f16_f32_f16.hpp" #include "device_reduce_instance_threadwise_f32_f32_f32.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp index e4b06cf96d..0f8c365007 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise.hpp @@ -3,13 +3,27 @@ #include "reduction_operator_mapping.hpp" #include "device_reduce_instance_impl_common.hpp" -#include "device_reduce_blockwise.hpp" +#include "device_reduce_multiblock.hpp" namespace ck { namespace tensor_operation { namespace device { namespace device_reduce_instance { +using reduce_configuration_1_instances_blockwise = std::tuple< + // clang-format off + // BlockSize | MThreadClusterSize | KThreadClusterSize + ReductionConfiguration_1<256, 128, 2>, + ReductionConfiguration_1<256, 64, 4>, + ReductionConfiguration_1<256, 32, 8>, + ReductionConfiguration_1<256, 16, 16>, + ReductionConfiguration_1<256, 8, 32>, + ReductionConfiguration_1<256, 4, 64>, + ReductionConfiguration_1<256, 2, 128>, + ReductionConfiguration_1<256, 1, 256> + // clang-format on + >; + #ifdef QUICK_REDUCE_TEST using reduce_configuration_2_instances_blockwise = std::tuple< // clang-format off @@ -47,10 +61,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple< >; #endif -template +template using deviceReduceBlockWisePtrType = DeviceReducePtr< - typename reduce_unary_operator::InElementwiseOperation, - typename reduce_unary_operator::AccElementwiseOperation>; + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; template + bool PropagateNan, + bool UseIndex> void add_device_reduce_instance_blockwise( - std::vector>& device_op_instances) + std::vector>& device_op_instances) { - using ReduceOperation = typename reduce_binary_operator::opType; + using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; constexpr bool Indexable = (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX); - constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES); + constexpr bool OutputIndex = Indexable && UseIndex; - constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true; + static_for<0, std::tuple_size::value, 1>{}( + [&](auto i) { + using cfg1 = remove_cvref_t(reduce_configuration_1_instances_blockwise{}))>; - static_for<0, std::tuple_size::value, 1>{}([&](auto i) { - using cfg1 = - remove_cvref_t(reduce_configuration_1_instances{}))>; + static_for<0, std::tuple_size::value, 1>{}( + [&](auto j) { + using cfg2 = remove_cvref_t(reduce_configuration_2_instances_blockwise{}))>; - static_for<0, std::tuple_size::value, 1>{}( - [&](auto j) { - using cfg2 = remove_cvref_t(reduce_configuration_2_instances_blockwise{}))>; + using ReduceOpInstance = + DeviceReduceMultiBlock; - using ReduceOpInstance = DeviceReduceBlockWise; - - device_op_instances.push_back( - std::make_unique(ReduceOpInstance{})); - }); - }); + device_op_instances.push_back( + std::make_unique(ReduceOpInstance{})); + }); + }); }; -#define ADD_BLOCKWISE_INST_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - template void add_device_reduce_instance_blockwise( \ - std::vector> & device_op_instances) +#define ADD_BLOCKWISE_INST_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + template void add_device_reduce_instance_blockwise( \ + std::vector> & device_op_instances) -#define ADD_BLOCKWISE_INST_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_BLOCKWISE_INST_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ +#define ADD_BLOCKWISE_INST_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_BLOCKWISE_INST_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ NumReduceDim) -#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - extern template void add_device_reduce_instance_blockwise( \ - std::vector::InElementwiseOperation, \ - typename reduce_unary_operator:: \ - AccElementwiseOperation>> & \ - device_op_instances) +#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_blockwise( \ + std::vector> & device_op_instances) -#define ADD_BLOCKWISE_INST_REF_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ +#define ADD_BLOCKWISE_INST_REF_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ NumReduceDim) } // namespace device_reduce_instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp index 0ae3289a0d..3cad45f2e5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16.hpp @@ -1,8 +1,7 @@ #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "data_type.hpp" #include "device_reduce_instance_blockwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp index e7bdb15d92..441c1aec3f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16.hpp @@ -1,8 +1,7 @@ #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "data_type.hpp" #include "device_reduce_instance_blockwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp index dad0d86350..ca8532a458 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16.hpp @@ -1,8 +1,7 @@ #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "data_type.hpp" #include "device_reduce_instance_blockwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp index 34ec15db2b..64f504c9da 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_blockwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp index b08f35ad09..9e84ee34fb 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_blockwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp index 65cdd45340..a37e3bdeb9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_blockwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp index f4a6677b3e..1d8695bbb0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_blockwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp index 7f67138e6b..b5c19b7207 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_blockwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call.hpp deleted file mode 100644 index 8e47bbfb6a..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call.hpp +++ /dev/null @@ -1,165 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_HPP - -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_impl_common.hpp" -#include "device_reduce_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -#ifdef QUICK_REDUCE_TEST -using reduce_configuration_2_instances_blockwise_second_call = std::tuple< - // clang-format off - // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize - ReductionConfiguration_2<1, 2, 1, 1, 2>, - ReductionConfiguration_2<1, 1, 1, 1, 3> - // clang-format on - >; -#else -using reduce_configuration_2_instances_blockwise_second_call = std::tuple< - // clang-format off - // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize - ReductionConfiguration_2<1, 4, 1, 1, 8>, - ReductionConfiguration_2<1, 4, 1, 1, 4>, - ReductionConfiguration_2<1, 2, 1, 1, 2>, - - ReductionConfiguration_2<1, 1, 1, 1, 3>, - ReductionConfiguration_2<1, 1, 1, 1, 5>, - ReductionConfiguration_2<1, 1, 1, 1, 7>, - ReductionConfiguration_2<1, 1, 1, 1, 11> - // clang-format on - >; -#endif - -template -using deviceReduceBlockWiseSecondCallPtrType = DeviceReducePtr< - typename reduce_unary_operator::InElementwiseOperation, - typename reduce_unary_operator::AccElementwiseOperation>; - -template -void add_device_reduce_instance_blockwise_second_call( - std::vector>& - device_op_instances) -{ - using ReduceOperation = typename reduce_binary_operator::opType; - using InElementwiseOperation = - typename reduce_unary_operator:: - InElementwiseOperation; - using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; - - constexpr bool Indexable = - (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || - ReduceOpId == ReduceTensorOp::AMAX); - constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES); - - constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true; - - static_assert(std::is_same::value, - "InDataType and AccDataType should be the same to use " - "add_device_reduce_instance_blockwise_second_call!"); - - static_for<0, std::tuple_size::value, 1>{}([&](auto i) { - using cfg1 = - remove_cvref_t(reduce_configuration_1_instances{}))>; - - static_for<0, - std::tuple_size::value, - 1>{}([&](auto j) { - using cfg2 = remove_cvref_t(reduce_configuration_2_instances_blockwise_second_call{}))>; - - using ReduceOpInstance = DeviceReduceBlockWiseSecondCall; - - device_op_instances.push_back(std::make_unique(ReduceOpInstance{})); - }); - }); -}; - -#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - template void add_device_reduce_instance_blockwise_second_call( \ - std::vector> & \ - device_op_instances) - -#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ - NumReduceDim) - -#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - extern template void add_device_reduce_instance_blockwise_second_call( \ - std::vector< \ - DeviceReducePtr:: \ - InElementwiseOperation, \ - typename reduce_unary_operator:: \ - AccElementwiseOperation>> & \ - device_op_instances) - -#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ - NumReduceDim) - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp deleted file mode 100644 index 4ce19c7d0c..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F16_F16_F16_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F16_F16_F16_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp deleted file mode 100644 index c85419befc..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_B16_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_B16_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 2, 1); - -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp deleted file mode 100644 index d42e7e020f..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F16_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F16_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp deleted file mode 100644 index fcf244d1d3..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp +++ /dev/null @@ -1,59 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F32_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F32_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp deleted file mode 100644 index 72e806ee60..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F32_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F32_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp deleted file mode 100644 index 476c3a7d8f..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp +++ /dev/null @@ -1,59 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F64_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F64_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp deleted file mode 100644 index d46780483b..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I32_I32_I8_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I32_I32_I8_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp deleted file mode 100644 index 7b020fb439..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I8_I8_I8_HPP -#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I8_I8_I8_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp index b25645034c..721d98a718 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_impl_common.hpp @@ -30,20 +30,6 @@ struct ReductionConfiguration_2 static constexpr int KThreadSliceSize_ = KThreadSliceSize; }; -using reduce_configuration_1_instances = std::tuple< - // clang-format off - // BlockSize | MThreadClusterSize | KThreadClusterSize - ReductionConfiguration_1<256, 128, 2>, - ReductionConfiguration_1<256, 64, 4>, - ReductionConfiguration_1<256, 32, 8>, - ReductionConfiguration_1<256, 16, 16>, - ReductionConfiguration_1<256, 8, 32>, - ReductionConfiguration_1<256, 4, 64>, - ReductionConfiguration_1<256, 2, 128>, - ReductionConfiguration_1<256, 1, 256> - // clang-format on - >; - #define QUICK_REDUCE_TEST 1 } // namespace device_reduce_instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp index bf10080b5e..9f78933bde 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add.hpp @@ -3,13 +3,27 @@ #include "reduction_operator_mapping.hpp" #include "device_reduce_instance_impl_common.hpp" -#include "device_reduce_multiblock_atomic_add.hpp" +#include "device_reduce_multiblock.hpp" namespace ck { namespace tensor_operation { namespace device { namespace device_reduce_instance { +using reduce_configuration_1_instances_multiblock_atomic_add = std::tuple< + // clang-format off + // BlockSize | MThreadClusterSize | KThreadClusterSize + ReductionConfiguration_1<256, 128, 2>, + ReductionConfiguration_1<256, 64, 4>, + ReductionConfiguration_1<256, 32, 8>, + ReductionConfiguration_1<256, 16, 16>, + ReductionConfiguration_1<256, 8, 32>, + ReductionConfiguration_1<256, 4, 64>, + ReductionConfiguration_1<256, 2, 128>, + ReductionConfiguration_1<256, 1, 256> + // clang-format on + >; + #ifdef QUICK_REDUCE_TEST using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< // clang-format off @@ -47,12 +61,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple< >; #endif -template -using deviceReduceMultiBlockAtomicAddPtrType = - DeviceReducePtr:: - InElementwiseOperation, - typename reduce_unary_operator:: - AccElementwiseOperation>; +template +using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr< + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; template + bool PropagateNan, + bool UseIndex> void add_device_reduce_instance_multiblock_atomic_add( - std::vector>& - device_op_instances) + std::vector>& device_op_instances) { - using ReduceOperation = typename reduce_binary_operator::opType; + using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; constexpr bool Indexable = (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX); - constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES); + constexpr bool OutputIndex = Indexable && UseIndex; - constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true; - - static_assert(IndicesOpt == ReduceTensorIndices::NO_INDICES, - "AtomicAdd can only be used with reduction operations without indices!"); + static_assert(UseIndex == false, + "AtomicAdd can only be used with reduction operations using no index!"); constexpr bool op_acceptable = (ReduceOpId == ReduceTensorOp::ADD || ReduceOpId == ReduceTensorOp::MUL || @@ -94,9 +102,11 @@ void add_device_reduce_instance_multiblock_atomic_add( return; else { - static_for<0, std::tuple_size::value, 1>{}([&](auto i) { - using cfg1 = - remove_cvref_t(reduce_configuration_1_instances{}))>; + static_for<0, + std::tuple_size::value, + 1>{}([&](auto i) { + using cfg1 = remove_cvref_t(reduce_configuration_1_instances_multiblock_atomic_add{}))>; static_for< 0, @@ -105,24 +115,27 @@ void add_device_reduce_instance_multiblock_atomic_add( using cfg2 = remove_cvref_t(reduce_configuration_2_instances_multiblock_atomic_add{}))>; - using ReduceOpInstance = DeviceReduceMultiBlockAtomicAdd; + using ReduceOpInstance = + DeviceReduceMultiBlock; device_op_instances.push_back( std::make_unique(ReduceOpInstance{})); @@ -132,54 +145,49 @@ void add_device_reduce_instance_multiblock_atomic_add( }; #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ template void add_device_reduce_instance_multiblock_atomic_add( \ - std::vector> & \ - device_op_instances) + PropagateNan, \ + UseIndex>( \ + std::vector> & device_op_instances) -#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ NumReduceDim) -#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - extern template void add_device_reduce_instance_multiblock_atomic_add( \ - std::vector::InElementwiseOperation, \ - typename reduce_unary_operator:: \ - AccElementwiseOperation>> & \ - device_op_instances) +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_multiblock_atomic_add( \ + std::vector> & device_op_instances) -#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ +#define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ NumReduceDim) } // namespace device_reduce_instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp index 58f90bb94f..4e39cf49f6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp @@ -1,8 +1,7 @@ #ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP #define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_B16_F32_F32_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "data_type.hpp" #include "device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp index f4c766ca03..73424322ae 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp @@ -1,8 +1,7 @@ #ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP #define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F16_F32_F32_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "data_type.hpp" #include "device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp index c2f2564fc9..ecc9c4ea87 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP #define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F32_F32_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp index 830dcf9407..41a60d5b70 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP #define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F32_F64_F32_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_multiblock_atomic_add.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp new file mode 100644 index 0000000000..bdcca274d7 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp @@ -0,0 +1,29 @@ +#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F64_F64_F64_HPP +#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_ATOMIC_ADD_F64_F64_F64_HPP + +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck + +#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce.hpp deleted file mode 100644 index 5c323ec175..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce.hpp +++ /dev/null @@ -1,174 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_HPP - -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_impl_common.hpp" -#include "device_reduce_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -#ifdef QUICK_REDUCE_TEST -using reduce_configuration_2_instances_multiblock_partial_reduce = std::tuple< - // clang-format off - // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize - ReductionConfiguration_2<0, 1, 1, 2, 1>, - ReductionConfiguration_2<1, 2, 1, 1, 2>, - ReductionConfiguration_2<0, 1, 1, 3, 1>, - ReductionConfiguration_2<1, 1, 1, 1, 3> - // clang-format on - >; -#else -using reduce_configuration_2_instances_multiblock_partial_reduce = std::tuple< - // clang-format off - // InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize - ReductionConfiguration_2<0, 4, 1, 8, 1>, - ReductionConfiguration_2<0, 4, 1, 4, 1>, - ReductionConfiguration_2<0, 2, 1, 2, 1>, - - ReductionConfiguration_2<1, 4, 1, 1, 8>, - ReductionConfiguration_2<1, 4, 1, 1, 4>, - ReductionConfiguration_2<1, 2, 1, 1, 2>, - - // special instances - ReductionConfiguration_2<0, 1, 1, 3, 1>, - ReductionConfiguration_2<0, 1, 1, 5, 1>, - ReductionConfiguration_2<0, 1, 1, 7, 1>, - ReductionConfiguration_2<0, 1, 1, 11, 1>, - - ReductionConfiguration_2<0, 1, 1, 1, 3>, - ReductionConfiguration_2<0, 1, 1, 1, 5>, - ReductionConfiguration_2<0, 1, 1, 1, 7>, - ReductionConfiguration_2<0, 1, 1, 1, 11> - // clang-format on - >; -#endif - -template -using deviceReduceMultiBlockPartialReducePtrType = DeviceReducePtr< - typename reduce_unary_operator::InElementwiseOperation, - typename reduce_unary_operator::AccElementwiseOperation>; - -template -void add_device_reduce_instance_multiblock_partial_reduce( - std::vector>& - device_op_instances) -{ - using ReduceOperation = typename reduce_binary_operator::opType; - using InElementwiseOperation = - typename reduce_unary_operator:: - InElementwiseOperation; - using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; - - constexpr bool Indexable = - (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || - ReduceOpId == ReduceTensorOp::AMAX); - constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES); - - constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true; - - static_for<0, std::tuple_size::value, 1>{}([&](auto i) { - using cfg1 = - remove_cvref_t(reduce_configuration_1_instances{}))>; - - static_for< - 0, - std::tuple_size::value, - 1>{}([&](auto j) { - using cfg2 = remove_cvref_t(reduce_configuration_2_instances_multiblock_partial_reduce{}))>; - - using ReduceOpInstance = DeviceReduceMultiBlockPartialReduce; - - device_op_instances.push_back(std::make_unique(ReduceOpInstance{})); - }); - }); -}; - -#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - template void add_device_reduce_instance_multiblock_partial_reduce( \ - std::vector> & \ - device_op_instances) - -#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ - NumReduceDim) - -#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - extern template void add_device_reduce_instance_multiblock_partial_reduce( \ - std::vector< \ - DeviceReducePtr:: \ - InElementwiseOperation, \ - typename reduce_unary_operator:: \ - AccElementwiseOperation>> & \ - device_op_instances) - -#define ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ - NumReduceDim) - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp deleted file mode 100644 index d25645ad1e..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp +++ /dev/null @@ -1,60 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_B16_F32_B16_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_B16_F32_B16_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); - -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp deleted file mode 100644 index 05549fc702..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F16_F16_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F16_F16_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp deleted file mode 100644 index 3e4aaef51b..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F32_F16_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F16_F32_F16_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp deleted file mode 100644 index 2a1e4e7bf0..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp +++ /dev/null @@ -1,52 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F32_F32_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F32_F32_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1); - -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp deleted file mode 100644 index f95e3001ee..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F64_F32_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F32_F64_F32_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(float, double, float, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp deleted file mode 100644 index fac65128b6..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F64_F64_F64_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_F64_F64_F64_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1); - -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1); - -// Will be moved to use MultiBlockAtomicAdd -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp deleted file mode 100644 index 895c144c66..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp +++ /dev/null @@ -1,31 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_I8_I32_I8_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_I8_I32_I8_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp deleted file mode 100644 index d6bee57fcd..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp +++ /dev/null @@ -1,47 +0,0 @@ -#ifndef DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_I8_I8_I8_HPP -#define DEVICE_REDUCE_INSTANCE_MULTIBLOCK_PARTIAL_REDUCE_I8_I8_I8_HPP - -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_REF_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck - -#endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp index f3a0781c2b..563dd09b10 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise.hpp @@ -47,10 +47,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple< >; #endif -template +template using deviceReduceThreadWisePtrType = DeviceReducePtr< - typename reduce_unary_operator::InElementwiseOperation, - typename reduce_unary_operator::AccElementwiseOperation>; + typename reduce_unary_operator::InElementwiseOperation, + typename reduce_unary_operator::AccElementwiseOperation>; template + bool PropagateNan, + bool UseIndex> void add_device_reduce_instance_threadwise( - std::vector>& device_op_instances) + std::vector>& device_op_instances) { - using ReduceOperation = typename reduce_binary_operator::opType; + using ReduceOperation = typename reduce_binary_operator::opType; using InElementwiseOperation = - typename reduce_unary_operator::InElementwiseOperation; + typename reduce_unary_operator::InElementwiseOperation; using AccElementwiseOperation = - typename reduce_unary_operator:: - AccElementwiseOperation; + typename reduce_unary_operator::AccElementwiseOperation; constexpr bool Indexable = (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX); - constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES); - - constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true; + constexpr bool OutputIndex = Indexable && UseIndex; using cfg1 = ReductionConfiguration_1<256, 256, 1>; @@ -93,10 +90,9 @@ void add_device_reduce_instance_threadwise( InElementwiseOperation, AccElementwiseOperation, PropagateNan, - NeedIndices, + OutputIndex, + false, // HaveIndexInputIfOutputIndex cfg1::BlockSize_, - cfg1::MThreadClusterSize_, - cfg1::KThreadClusterSize_, cfg2::MThreadSliceSize_, cfg2::KThreadSliceSize_, cfg2::InSrcVectorDim_, @@ -107,54 +103,50 @@ void add_device_reduce_instance_threadwise( }); }; -#define ADD_THREADWISE_INST_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - template void add_device_reduce_instance_threadwise( \ - std::vector> & device_op_instances) +#define ADD_THREADWISE_INST_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + template void add_device_reduce_instance_threadwise( \ + std::vector> & device_op_instances) -#define ADD_THREADWISE_INST_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_THREADWISE_INST_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ +#define ADD_THREADWISE_INST_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_THREADWISE_INST_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ NumReduceDim) -#define ADD_THREADWISE_INST_REF_BY_TYPE( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - extern template void add_device_reduce_instance_threadwise( \ - std::vector::InElementwiseOperation, \ - typename reduce_unary_operator:: \ - AccElementwiseOperation>> & \ - device_op_instances) +#define ADD_THREADWISE_INST_REF_BY_TYPE( \ + inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \ + extern template void add_device_reduce_instance_threadwise( \ + std::vector> & device_op_instances) -#define ADD_THREADWISE_INST_REF_BY_ID( \ - inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ - ADD_THREADWISE_INST_REF_BY_TYPE(inT, \ - compT, \ - outT, \ - static_cast(ReduceOpId), \ - static_cast(NanOpt), \ - static_cast(IndicesOpt), \ - Rank, \ +#define ADD_THREADWISE_INST_REF_BY_ID( \ + inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ + ADD_THREADWISE_INST_REF_BY_TYPE(inT, \ + compT, \ + outT, \ + static_cast(ReduceOpId), \ + static_cast(NanOpt), \ + static_cast(IndicesOpt), \ + Rank, \ NumReduceDim) } // namespace device_reduce_instance diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp index f11d9118c9..0291f33214 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16.hpp @@ -1,8 +1,7 @@ #ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP #define DEVICE_REDUCE_INSTANCE_THREADWISE_B16_F32_B16_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "data_type.hpp" #include "device_reduce_instance_threadwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp index fe220335c5..7ab1bebc5f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16.hpp @@ -1,8 +1,7 @@ #ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP #define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F16_F16_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "data_type.hpp" #include "device_reduce_instance_threadwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp index 970559cfac..39c3d10660 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16.hpp @@ -1,8 +1,7 @@ #ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP #define DEVICE_REDUCE_INSTANCE_THREADWISE_F16_F32_F16_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" +#include "data_type.hpp" #include "device_reduce_instance_threadwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp index 66c33a72a4..3c47bfd189 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP #define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F32_F32_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_threadwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp index 196f142dbf..9df9f6f1fa 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP #define DEVICE_REDUCE_INSTANCE_THREADWISE_F32_F64_F32_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_threadwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp index 4f3e1448d0..00ab218f20 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP #define DEVICE_REDUCE_INSTANCE_THREADWISE_F64_F64_F64_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_threadwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp index 8f19a5d0a2..de7445b043 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP #define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I32_I8_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_threadwise.hpp" namespace ck { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp index 83bd48cd3f..1ea1ee745e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8.hpp @@ -1,8 +1,6 @@ #ifndef DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP #define DEVICE_REDUCE_INSTANCE_THREADWISE_I8_I8_I8_HPP -#include "reduction_enums.hpp" -#include "reduction_operator_mapping.hpp" #include "device_reduce_instance_threadwise.hpp" namespace ck { diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 0442fd1003..368da4d207 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -1,4 +1,5 @@ #pragma once + #include #include #include @@ -22,7 +23,7 @@ check_err(const std::vector& out, const std::vector& ref, const std::string& msg = "Error: Incorrect results!", double rtol = 1e-5, - double atol = 1e-8) + double atol = 3e-6) { if(out.size() != ref.size()) { @@ -167,20 +168,34 @@ check_err(const std::vector& out, return false; } + bool res{true}; + int err_count = 0; + int64_t err = 0; + int64_t max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { - const auto out_v = static_cast(out[i]); - const auto ref_v = static_cast(ref[i]); + int64_t o = out[i]; + int64_t r = ref[i]; + err = std::abs(o - r); - if(out_v != ref_v) + if(err > 0) { - std::cout << "out[" << i << "] != ref[" << i << "]: " << out_v << " != " << ref_v - << std::endl - << msg << std::endl; - return false; + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast(out[i]) + << " != " << static_cast(ref[i]) << std::endl + << msg << std::endl; + } + res = false; } } - return true; + if(!res) + { + std::cout << "max err: " << max_err << std::endl; + } + return res; } } // namespace utils diff --git a/library/include/ck/library/utility/conv_fwd_util.hpp b/library/include/ck/library/utility/conv_util.hpp similarity index 94% rename from library/include/ck/library/utility/conv_fwd_util.hpp rename to library/include/ck/library/utility/conv_util.hpp index a29eb814fd..409fa5aff2 100644 --- a/library/include/ck/library/utility/conv_fwd_util.hpp +++ b/library/include/ck/library/utility/conv_util.hpp @@ -146,19 +146,19 @@ struct ConvParams const std::vector& left_pads, const std::vector& right_pads); - ck::index_t num_dim_spatial; - ck::index_t N; - ck::index_t K; - ck::index_t C; + ck::index_t num_dim_spatial_; + ck::index_t N_; + ck::index_t K_; + ck::index_t C_; - std::vector filter_spatial_lengths; - std::vector input_spatial_lengths; + std::vector filter_spatial_lengths_; + std::vector input_spatial_lengths_; - std::vector conv_filter_strides; - std::vector conv_filter_dilations; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; - std::vector input_left_pads; - std::vector input_right_pads; + std::vector input_left_pads_; + std::vector input_right_pads_; std::vector GetOutputSpatialLengths() const; }; @@ -268,10 +268,10 @@ void run_reference_convolution_forward(const ConvParams& params, auto ref_argument = ref_conv.MakeArgument(input, weights, output, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, PassThrough{}, PassThrough{}, PassThrough{}); @@ -402,8 +402,8 @@ template , - typename WeightsInitFun = FillUniform> + typename InputInitFun = FillUniformDistribution, + typename WeightsInitFun = FillUniformDistribution> class ConvFwdOpInstance : public ck::utils::OpInstance { using DeviceConvFwdOp = tensor_operation::device:: @@ -422,8 +422,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance input_dims{static_cast(params_.N), - static_cast(params_.C)}; + std::vector input_dims{static_cast(params_.N_), + static_cast(params_.C_)}; input_dims.insert(std::end(input_dims), - std::begin(params_.input_spatial_lengths), - std::end(params_.input_spatial_lengths)); + std::begin(params_.input_spatial_lengths_), + std::end(params_.input_spatial_lengths_)); - std::vector filter_dims{static_cast(params_.K), - static_cast(params_.C)}; + std::vector filter_dims{static_cast(params_.K_), + static_cast(params_.C_)}; filter_dims.insert(std::end(filter_dims), - std::begin(params_.filter_spatial_lengths), - std::end(params_.filter_spatial_lengths)); + std::begin(params_.filter_spatial_lengths_), + std::end(params_.filter_spatial_lengths_)); auto input = std::make_unique>( get_host_tensor_descriptor(input_dims, InLayout{})); @@ -465,8 +465,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance GetOutputTensor() const override { - std::vector output_dims{static_cast(params_.N), - static_cast(params_.K)}; + std::vector output_dims{static_cast(params_.N_), + static_cast(params_.K_)}; output_dims.insert(std::end(output_dims), std::begin(output_spatial_lengths_), std::end(output_spatial_lengths_)); @@ -522,16 +522,16 @@ class ConvFwdOpInstance : public ck::utils::OpInstance(in_device_buffers[0]->GetDeviceBuffer()), static_cast(in_device_buffers[1]->GetDeviceBuffer()), static_cast(out_device_buffer->GetDeviceBuffer()), - params_.N, - params_.K, - params_.C, - params_.input_spatial_lengths, - params_.filter_spatial_lengths, + params_.N_, + params_.K_, + params_.C_, + params_.input_spatial_lengths_, + params_.filter_spatial_lengths_, output_spatial_lengths_, - params_.conv_filter_strides, - params_.conv_filter_dilations, - params_.input_left_pads, - params_.input_right_pads, + params_.conv_filter_strides_, + params_.conv_filter_dilations_, + params_.input_left_pads_, + params_.input_right_pads_, InElementwiseOp{}, WeiElementwiseOp{}, OutElementwiseOp{}); @@ -539,20 +539,20 @@ class ConvFwdOpInstance : public ck::utils::OpInstance(params_.N, - params_.C, - params_.K, - params_.input_spatial_lengths, - params_.filter_spatial_lengths, + return get_btype(params_.N_, + params_.C_, + params_.K_, + params_.input_spatial_lengths_, + params_.filter_spatial_lengths_, output_spatial_lengths_); } @@ -560,8 +560,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance output_spatial_lengths_; const bool do_init_; - const InputInitFun& input_init_f_; - const WeightsInitFun& weights_init_f_; + InputInitFun input_init_f_; + WeightsInitFun weights_init_f_; }; } // namespace conv diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index f44aec969d..8c31e56beb 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include "data_type.hpp" @@ -8,46 +9,56 @@ namespace ck { namespace utils { -// template -// struct FillUniform; - -// TODO: what's wrong with this specialization??? -// err: segmentation fault in mt19937 - infinite loop like. -// template -// struct FillUniform::value && -// !std::is_same::value>::type> -// { -// int a_{0}; -// int b_{5}; -// // T a_ = T{0}; -// // T b_ = T{5}; - -// template -// void operator()(ForwardIter first, ForwardIter last) const -// { -// std::mt19937 gen{11939}; -// std::uniform_int_distribution dis(a_, b_); -// std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); -// } -// }; - -// struct FillUniform::value || -// std::is_same::value>::type> template -struct FillUniform +struct FillUniformDistribution { - float a_{0}; - float b_{5}; + float a_{-5.f}; + float b_{5.f}; template void operator()(ForwardIter first, ForwardIter last) const { - std::mt19937 gen{11939}; - std::uniform_real_distribution<> dis(a_, b_); + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); } }; +// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below. +// However this produces segfaults in std::mt19937 which look like inifite loop. +// template +// struct FillUniformDistributionIntegerValue +// { +// int a_{-5}; +// int b_{5}; +// +// template +// void operator()(ForwardIter first, ForwardIter last) const +// { +// std::mt19937 gen(11939); +// std::uniform_int_distribution dis(a_, b_); +// std::generate( +// first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); +// } +// }; + +// Workaround for uniform_int_distribution not working as expected. See note above.< +template +struct FillUniformDistributionIntegerValue +{ + float a_{-5.f}; + float b_{5.f}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); + std::generate( + first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); + } +}; + template struct FillMonotonicSeq { diff --git a/library/include/ck/library/utility/op_instance_engine.hpp b/library/include/ck/library/utility/op_instance_engine.hpp index ec88b4e1b9..1d11b62a4a 100644 --- a/library/include/ck/library/utility/op_instance_engine.hpp +++ b/library/include/ck/library/utility/op_instance_engine.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -78,7 +79,8 @@ class OpInstanceRunEngine template > OpInstanceRunEngine(const OpInstanceT& op_instance, - const ReferenceOp& reference_op = ReferenceOp{}) + const ReferenceOp& reference_op = ReferenceOp{}, + bool do_verification = true) : op_instance_{op_instance} { in_tensors_ = op_instance_.GetInputTensors(); @@ -88,8 +90,11 @@ class OpInstanceRunEngine const Tensor&..., Tensor&>) { - ref_output_ = op_instance_.GetOutputTensor(); - CallRefOpUnpackArgs(reference_op, std::make_index_sequence{}); + if(do_verification) + { + ref_output_ = op_instance_.GetOutputTensor(); + CallRefOpUnpackArgs(reference_op, std::make_index_sequence{}); + } } AllocateDeviceInputTensors(std::make_index_sequence{}); out_device_buffer_ = @@ -110,6 +115,7 @@ class OpInstanceRunEngine op_ptr.get(), in_device_buffers_, out_device_buffer_); if(op_ptr->IsSupportedArgument(argument.get())) { + std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl; invoker->Run(argument.get()); out_device_buffer_->FromDevice(out_tensor_->mData.data()); if(!ref_output_) @@ -119,20 +125,26 @@ class OpInstanceRunEngine " You have to provide reference function."); } // TODO: enable flexible use of custom check_error functions - res = res && check_err(out_tensor_->mData, ref_output_->mData); + bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData); + std::cout << (inst_res ? "SUCCESS" : "FAILURE") << std::endl; + res = res && inst_res; out_device_buffer_->SetZero(); } + else + { + std::cout << "Given conv problem is not supported by instance: \n\t>>>>" + << op_ptr->GetTypeString() << std::endl; + } } return res; } template ProfileBestConfig Profile(const std::vector& op_ptrs, - int nrepeat = 100, + bool time_kernel = false, bool do_verification = false, bool do_log = false) { - bool res{true}; ProfileBestConfig best_config; for(auto& op_ptr : op_ptrs) @@ -143,7 +155,7 @@ class OpInstanceRunEngine if(op_ptr->IsSupportedArgument(argument.get())) { std::string op_name = op_ptr->GetTypeString(); - float avg_time = invoker->Run(argument.get(), nrepeat); + float avg_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); std::size_t flops = op_instance_.GetFlops(); std::size_t num_btype = op_instance_.GetBtype(); @@ -153,7 +165,7 @@ class OpInstanceRunEngine std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl; - if(tflops < best_config.best_tflops) + if(avg_time < best_config.best_avg_time) { best_config.best_op_name = op_name; best_config.best_tflops = tflops; @@ -171,7 +183,7 @@ class OpInstanceRunEngine " You have to provide reference function."); } // TODO: enable flexible use of custom check_error functions - res = res && CheckErr(out_tensor_->mData, ref_output_->mData); + CheckErr(out_tensor_->mData, ref_output_->mData); if(do_log) {} } @@ -223,7 +235,7 @@ class OpInstanceRunEngine template bool CheckErr(const std::vector& dev_out, const std::vector& ref_out) const { - return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", atol_, rtol_); + return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_); } }; diff --git a/library/src/host_tensor/CMakeLists.txt b/library/src/host_tensor/CMakeLists.txt index fd100e477f..2a020b763d 100644 --- a/library/src/host_tensor/CMakeLists.txt +++ b/library/src/host_tensor/CMakeLists.txt @@ -10,10 +10,31 @@ set(HOST_TENSOR_SOURCE host_tensor.cpp ) -add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE}) +add_library(host_tensor STATIC ${HOST_TENSOR_SOURCE}) +add_library(composable_kernel::host_tensor ALIAS host_tensor) + target_compile_features(host_tensor PUBLIC) set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON) target_include_directories(host_tensor SYSTEM PUBLIC $) -install(TARGETS host_tensor LIBRARY DESTINATION lib) + +target_include_directories(host_tensor PUBLIC + "$" + "$" + "$" +) + +install(TARGETS host_tensor + EXPORT host_tensorTargets + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +) + +install(EXPORT host_tensorTargets + FILE composable_kernelhost_tensorTargets.cmake + NAMESPACE composable_kernel:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) clang_tidy_check(host_tensor) diff --git a/library/src/host_tensor/device.cpp b/library/src/host_tensor/device.cpp index 3e80df80fb..9f0d982dbc 100644 --- a/library/src/host_tensor/device.cpp +++ b/library/src/host_tensor/device.cpp @@ -2,7 +2,7 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) { - hipGetErrorString(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); + hip_check_error(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); } void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } @@ -11,49 +11,48 @@ std::size_t DeviceMem::GetBufferSize() { return mMemSize; } void DeviceMem::ToDevice(const void* p) { - hipGetErrorString( - hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); + hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); } void DeviceMem::FromDevice(void* p) { - hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); + hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); } -void DeviceMem::SetZero() { hipGetErrorString(hipMemset(mpDeviceBuf, 0, mMemSize)); } +void DeviceMem::SetZero() { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); } -DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } +DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); } struct KernelTimerImpl { KernelTimerImpl() { - hipGetErrorString(hipEventCreate(&mStart)); - hipGetErrorString(hipEventCreate(&mEnd)); + hip_check_error(hipEventCreate(&mStart)); + hip_check_error(hipEventCreate(&mEnd)); } ~KernelTimerImpl() { - hipGetErrorString(hipEventDestroy(mStart)); - hipGetErrorString(hipEventDestroy(mEnd)); + hip_check_error(hipEventDestroy(mStart)); + hip_check_error(hipEventDestroy(mEnd)); } void Start() { - hipGetErrorString(hipDeviceSynchronize()); - hipGetErrorString(hipEventRecord(mStart, nullptr)); + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(mStart, nullptr)); } void End() { - hipGetErrorString(hipEventRecord(mEnd, nullptr)); - hipGetErrorString(hipEventSynchronize(mEnd)); + hip_check_error(hipEventRecord(mEnd, nullptr)); + hip_check_error(hipEventSynchronize(mEnd)); } float GetElapsedTime() const { float time; - hipGetErrorString(hipEventElapsedTime(&time, mStart, mEnd)); + hip_check_error(hipEventElapsedTime(&time, mStart, mEnd)); return time; } diff --git a/library/src/host_tensor/host_tensor.cpp b/library/src/host_tensor/host_tensor.cpp index 38b0796635..138e3fc254 100644 --- a/library/src/host_tensor/host_tensor.cpp +++ b/library/src/host_tensor/host_tensor.cpp @@ -25,7 +25,7 @@ std::size_t HostTensorDescriptor::GetElementSize() const std::size_t HostTensorDescriptor::GetElementSpace() const { std::size_t space = 1; - for(int i = 0; i < mLens.size(); ++i) + for(std::size_t i = 0; i < mLens.size(); ++i) { space += (mLens[i] - 1) * mStrides[i]; } @@ -68,7 +68,7 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream // FIXME: remove void bf16_to_f32_(const Tensor& src, Tensor& dst) { - for(int i = 0; i < src.mData.size(); ++i) + for(std::size_t i = 0; i < src.mData.size(); ++i) dst.mData[i] = ck::type_convert(src.mData[i]); } #endif diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 7b361b48bd..128aea334a 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform @@ -11,6 +12,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/thread ${PROJECT_SOURCE_DIR}/include/ck/tensor_operation/gpu/element ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor + ${PROJECT_SOURCE_DIR}/library/include/ck/library/host ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce ${PROJECT_SOURCE_DIR}/external/include/half @@ -18,7 +20,7 @@ include_directories(BEFORE function(add_instance_library INSTANCE_NAME) message("adding instance ${INSTANCE_NAME}") - add_library(${INSTANCE_NAME} SHARED ${ARGN}) + add_library(${INSTANCE_NAME} OBJECT ${ARGN}) target_compile_features(${INSTANCE_NAME} PUBLIC) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) endfunction(add_instance_library INSTANCE_NAME) @@ -28,6 +30,7 @@ add_subdirectory(gemm_bias2d) add_subdirectory(gemm_bias_relu) add_subdirectory(gemm_bias_relu_add) add_subdirectory(gemm_reduce) +add_subdirectory(gemm_bias_add_reduce) add_subdirectory(batched_gemm) add_subdirectory(conv1d_fwd) add_subdirectory(conv2d_fwd) @@ -41,3 +44,77 @@ add_subdirectory(convnd_bwd_data) add_subdirectory(grouped_gemm) add_subdirectory(conv2d_bwd_weight) add_subdirectory(batched_gemm_reduce) +add_subdirectory(gemm_add_add_fastgelu) + +add_library(device_operations STATIC + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + device_conv2d.cpp +) +add_library(composablekernels::device_operations ALIAS device_operations) + + +set(DEV_OPS_INC_DIRS + ${PROJECT_SOURCE_DIR}/include/ck/ + ${PROJECT_SOURCE_DIR}/library/include/ck/ + ${PROJECT_SOURCE_DIR}/external/include/ +) +target_compile_features(device_operations PUBLIC) +set_target_properties(device_operations PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(device_operations PUBLIC + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ +) + +#once new arches are enabled make this an option on the main cmake file +# and pass down here to be exported + +target_compile_options(device_operations PRIVATE + --offload-arch=gfx908 + --offload-arch=gfx90a +) + +# install(TARGETS device_operations LIBRARY DESTINATION lib) +install(TARGETS device_operations + EXPORT device_operationsTargets + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} +) +install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck) +install(EXPORT device_operationsTargets + FILE composable_kerneldevice_operationsTargets.cmake + NAMESPACE composable_kernel:: + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel +) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt index 35e24462b5..016c85f673 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -18,9 +18,9 @@ set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp; ) -add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) -target_compile_features(device_batched_gemm_instance PUBLIC) +add_library(device_batched_gemm_instance OBJECT ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) +# target_compile_features(device_batched_gemm_instance PUBLIC) set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib) +# install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib) clang_tidy_check(device_batched_gemm_instance) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index 59eb6cb1cc..0606df01f1 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,11 +1,12 @@ -set(DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE +set(DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp ) -add_instance_library(device_batched_gemm_reduce_instance ${DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE}) -install(TARGETS device_batched_gemm_reduce_instance LIBRARY DESTINATION lib) +add_instance_library(device_batched_gemm_reduce_instance OBJECT ${DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE}) +target_compile_features(device_batched_gemm_reduce_instance PUBLIC) +set_target_properties(device_batched_gemm_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) clang_tidy_check(device_batched_gemm_reduce_instance) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp index 3653169921..886863c73b 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -10,8 +10,9 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -using F16 = ck::half_t; -using F32 = float; +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -20,42 +21,50 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; -using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // c[g, m, n] = a[g, m, k] * b[g, n, k] -// d0[g, m] = reduce0(c[g, m, n]) -// d1[g, m] = reduce1(c[g, m, n]) using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances = std::tuple< // clang-format off - //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp index 070056980d..b5ddc43838 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -10,8 +10,9 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -using F16 = ck::half_t; -using F32 = float; +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -20,42 +21,50 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; -using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // c[g, m, n] = a[g, m, k] * b[g, n, k] -// d0[g, m] = reduce0(c[g, m, n]) -// d1[g, m] = reduce1(c[g, m, n]) using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances = std::tuple< // clang-format off - //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp index f242b3c12e..8426ab79c9 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -10,8 +10,9 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -using F16 = ck::half_t; -using F32 = float; +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -20,42 +21,50 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; -using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // c[g, m, n] = a[g, m, k] * b[g, n, k] -// d0[g, m] = reduce0(c[g, m, n]) -// d1[g, m] = reduce1(c[g, m, n]) using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances = std::tuple< // clang-format off - //##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp index cbf3c16171..7cd1908803 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -10,8 +10,9 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -using F16 = ck::half_t; -using F32 = float; +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -20,39 +21,47 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; -using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // c[g, m, n] = a[g, m, k] * b[g, n, k] -// d0[g, m] = reduce0(c[g, m, n]) -// d1[g, m] = reduce1(c[g, m, n]) using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances = std::tuple< // clang-format off - //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> // clang-format on >; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt index 6c7c3e4f78..77aa6198f5 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt @@ -6,9 +6,9 @@ set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp; ) -add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE}) -target_compile_features(device_conv1d_fwd_instance PUBLIC) +add_library(device_conv1d_fwd_instance OBJECT ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE}) +# target_compile_features(device_conv1d_fwd_instance PUBLIC) set_target_properties(device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib) +# install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib) clang_tidy_check(device_conv1d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp index 9288e40e56..a133300f73 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -28,15 +28,12 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple< -// clang-format off + // clang-format off //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if !CK_WORKAROUND_GITHUB_135 - // FIXME: this instance causes numerical errors. DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, -#endif DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt index d619ef4bf1..d7882a7d8b 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -6,9 +6,7 @@ set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; ) -add_library(device_conv2d_bwd_data_instance SHARED ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE}) -target_compile_features(device_conv2d_bwd_data_instance PUBLIC) +add_library(device_conv2d_bwd_data_instance OBJECT ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE}) set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_bwd_data_instance LIBRARY DESTINATION lib) clang_tidy_check(device_conv2d_bwd_data_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt index 6183e70b9b..7c384a882b 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_weight/CMakeLists.txt @@ -3,7 +3,7 @@ set(DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; ) -add_library(device_conv2d_bwd_weight_instance SHARED ${DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE}) +add_library(device_conv2d_bwd_weight_instance OBJECT ${DEVICE_CONV2D_BWD_WEIGHT_INSTANCE_SOURCE}) target_compile_features(device_conv2d_bwd_weight_instance PUBLIC) set_target_properties(device_conv2d_bwd_weight_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) install(TARGETS device_conv2d_bwd_weight_instance LIBRARY DESTINATION lib) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index 7483861524..1ef4a9b07e 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -6,9 +6,18 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; ) -add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) -target_compile_features(device_conv2d_fwd_instance PUBLIC) +set(DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; +) + +add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +add_library(device_convnd_2d_fwd_instance OBJECT ${DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE}) + set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) +set_target_properties(device_convnd_2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) clang_tidy_check(device_conv2d_fwd_instance) +clang_tidy_check(device_convnd_2d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 0000000000..de98151ef8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,113 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000..4b4a0fc25a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 0000000000..5603fc5d06 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,111 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 0000000000..b4447bcb82 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt index 27a9736a3f..ad66c73bf8 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt @@ -2,9 +2,7 @@ set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; ) -add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) -target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) +add_library(device_conv2d_fwd_bias_relu_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) clang_tidy_check(device_conv2d_fwd_bias_relu_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt index d7bec82174..36b1f6c153 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -2,9 +2,7 @@ set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; ) -add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) -target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) +add_library(device_conv2d_fwd_bias_relu_add_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) clang_tidy_check(device_conv2d_fwd_bias_relu_add_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt index c0942d5485..5906c7c5ac 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt @@ -3,9 +3,7 @@ set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; ) -add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) -target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) +add_library(device_conv2d_fwd_bias_relu_atomic_add_instance OBJECT ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) clang_tidy_check(device_conv2d_fwd_bias_relu_atomic_add_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt index f6849a7bb2..91a299c742 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt @@ -5,9 +5,8 @@ set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; ) -add_library(device_conv3d_fwd_instance SHARED ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE}) +add_library(device_conv3d_fwd_instance OBJECT ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE}) target_compile_features(device_conv3d_fwd_instance PUBLIC) set_target_properties(device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv3d_fwd_instance LIBRARY DESTINATION lib) clang_tidy_check(device_conv3d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp index 745d26904a..bff51affd1 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -28,15 +28,12 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple< -// clang-format off + // clang-format off //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if !CK_WORKAROUND_GITHUB_135 - // FIXME: this instance causes numerical errors. DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, -#endif DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, diff --git a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt index 9ee961ad74..037f860808 100644 --- a/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/convnd_bwd_data/CMakeLists.txt @@ -14,7 +14,7 @@ set(DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; ) -add_library(device_convnd_bwd_data_instance SHARED ${DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE}) +add_library(device_convnd_bwd_data_instance OBJECT ${DEVICE_CONVND_BWD_DATA_INSTANCE_SOURCE}) target_compile_features(device_convnd_bwd_data_instance PUBLIC) set_target_properties(device_convnd_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) install(TARGETS device_convnd_bwd_data_instance LIBRARY DESTINATION lib) diff --git a/library/src/tensor_operation_instance/gpu/device_conv2d.cpp b/library/src/tensor_operation_instance/gpu/device_conv2d.cpp new file mode 100644 index 0000000000..6b99433ffa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/device_conv2d.cpp @@ -0,0 +1,201 @@ +#include +#include "config.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" +#include "host_interface.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& instances); + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +struct DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl +{ + std::unique_ptr + MakeArgumentPointer(void* in_ptr, + void* wei_ptr, + void* out_ptr, + size_t N, + size_t K, + size_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) const + { + return el->MakeArgumentPointer(in_ptr, + wei_ptr, + out_ptr, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + } + std::unique_ptr MakeInvokerPointer() const + { + return el->MakeInvokerPointer(); + } + + std::string GetTypeString() { return el->GetTypeString(); } + bool IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg) + { + return el->IsSupportedArgument(arg); + } + + ck::tensor_operation::device::DeviceConvFwdPtr el; +}; + +DeviceConvFwdPtr_t::DeviceConvFwdPtr_t() : pImpl(nullptr) {} +DeviceConvFwdPtr_t::~DeviceConvFwdPtr_t() = default; +DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&) = default; +DeviceConvFwdPtr_t::DeviceConvFwdPtr_t(DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl& other) + : pImpl(std::make_unique(std::move(other))) +{ +} + +std::unique_ptr +DeviceConvFwdPtr_t::MakeArgumentPointer(void* in_ptr, + void* wei_ptr, + void* out_ptr, + size_t N, + size_t K, + size_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) const +{ + return pImpl->MakeArgumentPointer(in_ptr, + wei_ptr, + out_ptr, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); +} + +std::unique_ptr DeviceConvFwdPtr_t::MakeInvokerPointer() const +{ + return pImpl->MakeInvokerPointer(); +} + +std::string DeviceConvFwdPtr_t::GetTypeString() { return pImpl->GetTypeString(); } +bool DeviceConvFwdPtr_t::IsSupportedArgument(const DeviceConvFwdPtr_t::BaseArgument* arg_ptr) +{ + return pImpl->IsSupportedArgument(arg_ptr); +} + +using namespace ck::tensor_operation::device::device_conv2d_fwd_instance; +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); + } + return; +} + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); // Perhaps we can do better + } + return; +} + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); // Perhaps we can do better + } + return; +} + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); // Perhaps we can do better + } + return; +} + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t( + std::vector& instances) +{ + std::vector< + ck::tensor_operation::device::DeviceConvFwdPtr> + local_instances; + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(local_instances); + for(auto& kinder : local_instances) + { + DeviceConvFwdPtr_t::DeviceConvFwdPtrImpl tmp{std::move(kinder)}; + instances.emplace_back(tmp); + } + return; +} diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 5f057adcc5..8de1920bb3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -1,5 +1,8 @@ -# device_gemm_instance set(DEVICE_GEMM_INSTANCE_SOURCE + device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp; + device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp; + device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp; + device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; @@ -8,10 +11,10 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp; - device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp; device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp; @@ -33,12 +36,21 @@ set(DEVICE_GEMM_INSTANCE_SOURCE device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp; + device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp; + device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp; ) -add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) +add_library(device_gemm_instance OBJECT ${DEVICE_GEMM_INSTANCE_SOURCE}) target_compile_features(device_gemm_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) - -clang_tidy_check(device_gemm_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..db7f6af04b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..c4253bcc4c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..d19d11f1f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..cd86e5ceae --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f16_f16_f16_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..3fcc5fdfdc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,45 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..8cd32128b5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..4c4bfc440d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..c6077341b1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,46 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_f32_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_f32_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..91b68d4bf2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..13b185fd93 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..ff4a89beb4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..e32158a292 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,42 @@ +#include +#include "config.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_dl_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + // #########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // #########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_dl_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp index 4530d95c72..2185b55aac 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp index 4214c71efb..90966349b2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp index 39bb7e1473..aa5a13001c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp similarity index 97% rename from library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp index 2ddde9e630..82eec1164a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[m, k] * b[n, k] = c[m, n] -using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances = +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple< // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -45,11 +45,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances = // clang-format on >; -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( std::vector>& instances) { add_device_operation_instances(instances, - device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances{}); + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); } } // namespace device_gemm_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..fdc85dfc71 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_f64_f64_f64_km_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..e400cd9bbb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_f64_f64_f64_km_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..2f9241b93b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp @@ -0,0 +1,49 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..537fe2bdae --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp @@ -0,0 +1,54 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F64 = double; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, + DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> + // clang-format on + >; + +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000..789c5b628f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -0,0 +1,14 @@ +# device_gemm_add_add_fastgelu_instance +set(DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; + device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; +) + +add_library(device_gemm_add_add_fastgelu_instance OBJECT ${DEVICE_GEMM_ADD_ADD_FASTGELU_INSTANCE_SOURCE}) + +target_compile_features(device_gemm_add_add_fastgelu_instance PUBLIC) +set_target_properties(device_gemm_add_add_fastgelu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +clang_tidy_check(device_gemm_add_add_fastgelu_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..15ef0f00e8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,66 @@ +#include + +#include "config.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" +#include "device_gemm_multiple_d_xdl_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16 = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// e = elementwise((a * b), d) +// outout: e[m, n] +// input: a[k, m], b[k, n], d[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..54386e8a8a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,66 @@ +#include + +#include "config.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" +#include "device_gemm_multiple_d_xdl_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16 = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// e = elementwise((a * b), d) +// outout: e[m, n] +// input: a[k, m], b[n, k], d[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..b78fd155fa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,66 @@ +#include + +#include "config.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" +#include "device_gemm_multiple_d_xdl_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16 = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// e = elementwise((a * b), d) +// outout: e[m, n] +// input: a[m, k], b[k, n], d[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..4641cb40e0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,63 @@ +#include + +#include "config.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" +#include "device_gemm_multiple_d_xdl_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16 = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// e = elementwise((a * b), d) +// outout: e[m, n] +// input: a[m, k], b[n, k], d[m, n] +using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt index a0e5ba61a1..e2b0abb1d1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias2d/CMakeLists.txt @@ -10,9 +10,7 @@ set(DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp; ) -add_library(device_gemm_bias2d_instance SHARED ${DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE}) -target_compile_features(device_gemm_bias2d_instance PUBLIC) +add_library(device_gemm_bias2d_instance OBJECT ${DEVICE_GEMM_BIAS2D_INSTANCE_SOURCE}) set_target_properties(device_gemm_bias2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_gemm_bias2d_instance LIBRARY DESTINATION lib) clang_tidy_check(device_gemm_bias2d_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt new file mode 100644 index 0000000000..0d068646af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -0,0 +1,10 @@ +set(DEVICE_GEMM_REDUCE_INSTANCE_SOURCE + device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp + device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp + device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp + device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +) + +add_instance_library(device_gemm_bias_add_reduce_instance ${DEVICE_GEMM_REDUCE_INSTANCE_SOURCE}) +install(TARGETS device_gemm_bias_add_reduce_instance LIBRARY DESTINATION lib) +clang_tidy_check(device_gemm_bias_add_reduce_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..2e1a7f531c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -0,0 +1,81 @@ +#include +#include "config.hpp" +#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[k, n] +using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..db6140ea61 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -0,0 +1,81 @@ +#include +#include "config.hpp" +#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[k, m] * b[n, k] +using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..050473886f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -0,0 +1,81 @@ +#include +#include "config.hpp" +#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout| AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..c50e6cf83d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +#include +#include "config.hpp" +#include "device_gemm_bias_add_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[m, n] = a[m, k] * b[n, k] +using device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmBiasAddReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt index 69e05673d6..e2e7d4badd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu/CMakeLists.txt @@ -6,9 +6,7 @@ set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp; ) -add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) -target_compile_features(device_gemm_bias_relu_instance PUBLIC) +add_library(device_gemm_bias_relu_instance OBJECT ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) clang_tidy_check(device_gemm_bias_relu_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt index 016bc4be2d..a10dbb555d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_relu_add/CMakeLists.txt @@ -6,9 +6,7 @@ set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp; ) -add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) -target_compile_features(device_gemm_bias_relu_add_instance PUBLIC) +add_library(device_gemm_bias_relu_add_instance OBJECT ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) clang_tidy_check(device_gemm_bias_relu_add_instance) diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp index 2f1509b6c8..e1d2f2f6ff 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instance.cpp @@ -10,8 +10,9 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -using F16 = ck::half_t; -using F32 = float; +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -20,41 +21,50 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; -using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // c[m, n] = a[k, m] * b[k, n] -// d0[m] = reduce0(c[m, n]) -// d1[m] = reduce1(c[m, n]) using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp index c3e04287e4..81509a3fc5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -10,8 +10,9 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -using F16 = ck::half_t; -using F32 = float; +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -20,41 +21,50 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; -using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // c[m, n] = a[k, m] * b[n, k] -// d0[m] = reduce0(c[m, n]) -// d1[m] = reduce1(c[m, n]) using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp index e845c3bf82..4d13381d45 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -10,8 +10,9 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -using F16 = ck::half_t; -using F32 = float; +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -20,41 +21,50 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; -using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // c[m, n] = a[m, k] * b[n, k] -// d0[m] = reduce0(c[m, n]) -// d1[m] = reduce1(c[m, n]) using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + //###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp index a356170789..459d0cd473 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -10,8 +10,9 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { -using F16 = ck::half_t; -using F32 = float; +using F16 = ck::half_t; +using F32 = float; +using DPtrsGlobal = ck::Tuple; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -20,38 +21,47 @@ template using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ReduceSum = ck::reduce::Add; -using Square = ck::tensor_operation::element_wise::UnarySquare; +using ReduceSum = ck::reduce::Add; +using ReduceOps = ck::Tuple; + +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // c[m, n] = a[m, k] * b[n, k] -// d0[m] = reduce0(c[m, n]) -// d1[m] = reduce1(c[m, n]) using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| D1EleOp| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| - //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| - //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| - //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, - DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSum, Square, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, PassThrough, PassThrough, PassThrough, ReduceOps, DInElementOps, DOutElementOps, ReduceMemOp, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> // clang-format on >; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances( - std::vector>& instances) + std::vector< + DeviceGemmReducePtr>& + instances) { add_device_operation_instances( instances, device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index 8f591d8c49..6c5e31fddd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -6,7 +6,7 @@ set(DEVICE_GROUPED_GEMM_INSTANCE_SOURCE device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; ) -add_library(device_grouped_gemm_instance SHARED ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE}) +add_library(device_grouped_gemm_instance OBJECT ${DEVICE_GROUPED_GEMM_INSTANCE_SOURCE}) target_compile_features(device_grouped_gemm_instance PUBLIC) set_target_properties(device_grouped_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt index cced3a4b76..d566796c13 100644 --- a/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/reduce/CMakeLists.txt @@ -16,31 +16,14 @@ set(DEVICE_REDUCE_INSTANCE_SOURCE device_reduce_instance_threadwise_i8_i32_i8.cpp; device_reduce_instance_threadwise_i8_i8_i8.cpp; device_reduce_instance_threadwise_b16_f32_b16.cpp; - device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp; - device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp; - device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp; - device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp; - device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp; - device_reduce_instance_blockwise_second_call_i32_i32_i8.cpp; - device_reduce_instance_blockwise_second_call_i8_i8_i8.cpp; - device_reduce_instance_blockwise_second_call_f32_f32_b16.cpp; device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp; device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp; device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp; + device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp; device_reduce_instance_multiblock_atomic_add_b16_f32_f32.cpp; - device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp; - device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp; - device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp; - device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp; - device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp; - device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.cpp; - device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.cpp; - device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.cpp; ) -add_library(device_reduce_instance SHARED ${DEVICE_REDUCE_INSTANCE_SOURCE}) -target_compile_features(device_reduce_instance PUBLIC) +add_library(device_reduce_instance OBJECT ${DEVICE_REDUCE_INSTANCE_SOURCE}) set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_reduce_instance LIBRARY DESTINATION lib) clang_tidy_check(device_reduce_instance) diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp deleted file mode 100644 index 82a9c11413..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_b16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_b16.cpp deleted file mode 100644 index 6b8139c32c..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_b16.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 7, 0, 0, 2, 1); - -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, bhalf_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp deleted file mode 100644 index 267b9d4d9d..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, half_t, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp deleted file mode 100644 index 0036a89542..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp deleted file mode 100644 index 0512fa4158..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, float, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp deleted file mode 100644 index afe7f0752e..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i32_i32_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i32_i32_i8.cpp deleted file mode 100644 index 9cb3b8684f..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i32_i32_i8.cpp +++ /dev/null @@ -1,24 +0,0 @@ -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i8_i8_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i8_i8_i8.cpp deleted file mode 100644 index 8783a75486..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_second_call_i8_i8_i8.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "device_reduce_instance_blockwise_second_call.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); -ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp new file mode 100644 index 0000000000..497f2695be --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64.cpp @@ -0,0 +1,24 @@ +#include "device_reduce_instance_multiblock_atomic_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_reduce_instance { + +// clang-format off +// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); +ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); +// clang-format on + +} // namespace device_reduce_instance +} // namespace device +} // namespace tensor_operation + +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.cpp deleted file mode 100644 index d740fcfa8f..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1); - -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp deleted file mode 100644 index f57ed5ad86..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp deleted file mode 100644 index 724b364104..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp deleted file mode 100644 index 15028a0b4c..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1); - -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp deleted file mode 100644 index ec0ba3cf8e..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp deleted file mode 100644 index 9ff2dcd93b..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1); - -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2 -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1); - -// Will be moved to use MultiBlockAtomicAdd -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.cpp deleted file mode 100644 index 0e37c2947f..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.cpp +++ /dev/null @@ -1,24 +0,0 @@ -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.cpp b/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.cpp deleted file mode 100644 index 4634faed06..0000000000 --- a/library/src/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "device_reduce_instance_multiblock_partial_reduce.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_reduce_instance { - -// clang-format off -// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1); -ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1); -// clang-format on - -} // namespace device_reduce_instance -} // namespace device -} // namespace tensor_operation - -} // namespace ck diff --git a/library/src/utility/CMakeLists.txt b/library/src/utility/CMakeLists.txt index 3580ba1a8f..0914855d59 100644 --- a/library/src/utility/CMakeLists.txt +++ b/library/src/utility/CMakeLists.txt @@ -8,14 +8,14 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility ) -set(CONV_FWD_UTIL_SOURCE - conv_fwd_util.cpp +set(CONV_UTIL_SOURCE + conv_util.cpp ) -add_library(conv_fwd_util SHARED ${CONV_FWD_UTIL_SOURCE}) -target_link_libraries(conv_fwd_util PRIVATE host_tensor) -target_compile_features(conv_fwd_util PUBLIC) -set_target_properties(conv_fwd_util PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_include_directories(conv_fwd_util SYSTEM PUBLIC $) +add_library(conv_util SHARED ${CONV_UTIL_SOURCE}) +target_link_libraries(conv_util PRIVATE host_tensor) +target_compile_features(conv_util PUBLIC) +set_target_properties(conv_util PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(conv_util SYSTEM PUBLIC $) -clang_tidy_check(conv_fwd_util) +clang_tidy_check(conv_util) diff --git a/library/src/utility/conv_fwd_util.cpp b/library/src/utility/conv_util.cpp similarity index 62% rename from library/src/utility/conv_fwd_util.cpp rename to library/src/utility/conv_util.cpp index 1658450388..a60d1a3495 100644 --- a/library/src/utility/conv_fwd_util.cpp +++ b/library/src/utility/conv_util.cpp @@ -1,5 +1,5 @@ -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" namespace ck { namespace utils { @@ -37,16 +37,16 @@ std::size_t get_flops(ck::index_t N, } ConvParams::ConvParams() - : num_dim_spatial(2), - N(128), - K(256), - C(192), - filter_spatial_lengths(2, 3), - input_spatial_lengths(2, 71), - conv_filter_strides(2, 2), - conv_filter_dilations(2, 1), - input_left_pads(2, 1), - input_right_pads(2, 1) + : num_dim_spatial_(2), + N_(128), + K_(256), + C_(192), + filter_spatial_lengths_(2, 3), + input_spatial_lengths_(2, 71), + conv_filter_strides_(2, 2), + conv_filter_dilations_(2, 1), + input_left_pads_(2, 1), + input_right_pads_(2, 1) { } @@ -60,22 +60,23 @@ ConvParams::ConvParams(ck::index_t n_dim, const std::vector& dilations, const std::vector& left_pads, const std::vector& right_pads) - : num_dim_spatial(n_dim), - N(n_batch), - K(n_out_channels), - C(n_in_channels), - filter_spatial_lengths(filters_len), - input_spatial_lengths(input_len), - conv_filter_strides(strides), - conv_filter_dilations(dilations), - input_left_pads(left_pads), - input_right_pads(right_pads) + : num_dim_spatial_(n_dim), + N_(n_batch), + K_(n_out_channels), + C_(n_in_channels), + filter_spatial_lengths_(filters_len), + input_spatial_lengths_(input_len), + conv_filter_strides_(strides), + conv_filter_dilations_(dilations), + input_left_pads_(left_pads), + input_right_pads_(right_pads) { - if(filter_spatial_lengths.size() != num_dim_spatial || - input_spatial_lengths.size() != num_dim_spatial || - conv_filter_strides.size() != num_dim_spatial || - conv_filter_dilations.size() != num_dim_spatial || - input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial) + if(ck::type_convert(filter_spatial_lengths_.size()) != num_dim_spatial_ || + ck::type_convert(input_spatial_lengths_.size()) != num_dim_spatial_ || + ck::type_convert(conv_filter_strides_.size()) != num_dim_spatial_ || + ck::type_convert(conv_filter_dilations_.size()) != num_dim_spatial_ || + ck::type_convert(input_left_pads_.size()) != num_dim_spatial_ || + ck::type_convert(input_right_pads_.size()) != num_dim_spatial_) { throw( std::runtime_error("ConvParams::GetOutputSpatialLengths: " @@ -85,26 +86,28 @@ ConvParams::ConvParams(ck::index_t n_dim, std::vector ConvParams::GetOutputSpatialLengths() const { - if(filter_spatial_lengths.size() != num_dim_spatial || - input_spatial_lengths.size() != num_dim_spatial || - conv_filter_strides.size() != num_dim_spatial || - conv_filter_dilations.size() != num_dim_spatial || - input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial) + if(ck::type_convert(filter_spatial_lengths_.size()) != num_dim_spatial_ || + ck::type_convert(input_spatial_lengths_.size()) != num_dim_spatial_ || + ck::type_convert(conv_filter_strides_.size()) != num_dim_spatial_ || + ck::type_convert(conv_filter_dilations_.size()) != num_dim_spatial_ || + ck::type_convert(input_left_pads_.size()) != num_dim_spatial_ || + ck::type_convert(input_right_pads_.size()) != num_dim_spatial_) { throw( std::runtime_error("ConvParams::GetOutputSpatialLengths: " "parameter size is different from number of declared dimensions!")); } - std::vector out_spatial_len(num_dim_spatial, 0); - for(ck::index_t i = 0; i < num_dim_spatial; ++i) + std::vector out_spatial_len(num_dim_spatial_, 0); + for(ck::index_t i = 0; i < num_dim_spatial_; ++i) { // XEff = (X - 1) * conv_dilation_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - const ck::index_t idx_eff = (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1; + const ck::index_t idx_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; out_spatial_len[i] = - (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / - conv_filter_strides[i] + + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - idx_eff) / + conv_filter_strides_[i] + 1; } return out_spatial_len; @@ -114,40 +117,40 @@ ConvParams parse_conv_params(int num_dim_spatial, int arg_idx, char* const argv[ { ck::utils::conv::ConvParams params; - params.num_dim_spatial = num_dim_spatial; - params.N = std::stoi(argv[arg_idx++]); - params.K = std::stoi(argv[arg_idx++]); - params.C = std::stoi(argv[arg_idx++]); + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); - params.filter_spatial_lengths.resize(num_dim_spatial); + params.filter_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.input_spatial_lengths.resize(num_dim_spatial); + params.input_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_strides.resize(num_dim_spatial); + params.conv_filter_strides_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_dilations.resize(num_dim_spatial); + params.conv_filter_dilations_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); } - params.input_left_pads.resize(num_dim_spatial); + params.input_left_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); } - params.input_right_pads.resize(num_dim_spatial); + params.input_right_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); } return params; @@ -226,12 +229,12 @@ HostTensorDescriptor get_input_host_tensor_descriptor(const std::vectorGetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * BatchCount * M * N * K; - std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N) * BatchCount; diff --git a/profiler/include/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profile_batched_gemm_reduce_impl.hpp index a6399c20d8..d1737f588a 100644 --- a/profiler/include/profile_batched_gemm_reduce_impl.hpp +++ b/profiler/include/profile_batched_gemm_reduce_impl.hpp @@ -17,11 +17,20 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { +using F32 = float; +using F16 = ck::half_t; +using DPtrsGlobal = ck::Tuple; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::UnarySquare>; + DInElementOps, + DOutElementOps>; void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( std::vector&); @@ -53,7 +62,7 @@ template {-0.5, 0.5}, num_thread); } - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - using D0ReduceOp = ck::reduce::Add; - using D1ReduceOp = ck::reduce::Add; - using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare; + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using D0ReduceOp = ck::reduce::Add; + using D1ReduceOp = ck::reduce::Add; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + using DxsInElementOps = ck::Tuple; + using DxsOutElementOps = ck::Tuple; - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{}; - const auto d0_reduce_op = D0ReduceOp{}; - const auto d1_reduce_op = D1ReduceOp{}; - const auto d1_element_op = D1ElementOp{}; + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto dxs_in_element_op = DxsInElementOps{}; + const auto dxs_out_element_op = DxsOutElementOps{}; + const auto d0_reduce_op = D0ReduceOp{}; + const auto d1_reduce_op = D1ReduceOp{}; if(do_verification) { @@ -155,15 +168,15 @@ bool profile_batched_gemm_reduce_impl(int do_verification, { for(int m = 0; m < M; ++m) { - float d0_acc = d0_reduce_op.GetReductionZeroVal(); - float d1_acc = d1_reduce_op.GetReductionZeroVal(); + float d0_acc = d0_reduce_op.GetIdentityValue(); + float d1_acc = d1_reduce_op.GetIdentityValue(); for(int n = 0; n < N; ++n) { float d0_val = ck::type_convert(c_g_m_n_host_result(batch, m, n)); float d1_val; - d1_element_op(d1_val, d0_val); + UnarySquareElementOp{}(d1_val, d0_val); d0_reduce_op(d0_acc, d0_val); d1_reduce_op(d1_acc, d1_val); } @@ -180,6 +193,9 @@ bool profile_batched_gemm_reduce_impl(int do_verification, DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace()); DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace()); + auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer())); + a_device_buf.ToDevice(a_g_m_k.mData.data()); b_device_buf.ToDevice(b_g_k_n.mData.data()); @@ -241,8 +257,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), static_cast(b_device_buf.GetDeviceBuffer()), static_cast(c_device_buf.GetDeviceBuffer()), - static_cast(d0_device_buf.GetDeviceBuffer()), - static_cast(d1_device_buf.GetDeviceBuffer()), + &dxs_global, M, N, K, @@ -252,37 +267,20 @@ bool profile_batched_gemm_reduce_impl(int do_verification, a_element_op, b_element_op, c_element_op, - d1_element_op, + dxs_in_element_op, + dxs_out_element_op, BatchCount); auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - // warm up - invoker_ptr->Run(argument_ptr.get()); + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); - // timing - float total_time = 0; - - for(int i = 0; i < nrepeat; ++i) - { - // init DO, D1 to 0 - d0_device_buf.SetZero(); - d1_device_buf.SetZero(); - - KernelTimer timer; - - timer.Start(); - - invoker_ptr->Run(argument_ptr.get()); - - timer.End(); - - total_time += timer.GetElapsedTime(); - } - - float ave_time = total_time / nrepeat; + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::string gemm_name = gemm_ptr->GetTypeString(); diff --git a/profiler/include/profile_conv_bwd_data_impl.hpp b/profiler/include/profile_conv_bwd_data_impl.hpp deleted file mode 100644 index bec97e40f5..0000000000 --- a/profiler/include/profile_conv_bwd_data_impl.hpp +++ /dev/null @@ -1,283 +0,0 @@ -#pragma once - -#include "check_err.hpp" -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "device_conv_bwd_data.hpp" -#include "element_wise_operation.hpp" -#include "reference_conv_bwd_data.hpp" - -using F16 = ck::half_t; -using F32 = float; -using BF16 = ck::bhalf_t; -using INT8 = int8_t; -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv2d_bwd_data_instance { - -using DeviceConvBwdDataNoOpPtr = - DeviceConvBwdDataPtr; -void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances( - std::vector&); -void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances( - std::vector&); -void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances( - std::vector&); -void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( - std::vector&); -} // namespace device_conv2d_bwd_data_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -namespace ck { -namespace profiler { - -template -void profile_conv_bwd_data_impl(int do_verification, - int init_method, - bool do_log, - int nrepeat, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) -{ - const ck::index_t Y = filter_spatial_lengths[0]; - const ck::index_t X = filter_spatial_lengths[1]; - - const ck::index_t Hi = input_spatial_lengths[0]; - const ck::index_t Wi = input_spatial_lengths[1]; - - const ck::index_t Ho = output_spatial_lengths[0]; - const ck::index_t Wo = output_spatial_lengths[1]; - - auto f_host_tensor_descriptor = - [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W, auto layout) { - if constexpr(is_same::value || - is_same::value || - is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, H * W, W, 1})); - } - else if constexpr(is_same::value || - is_same::value || - is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, 1, W * C_, C_})); - } - }; - - Tensor in_n_c_hi_wi_host_result(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); - Tensor in_n_c_hi_wi_device_result( - f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); - Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); - Tensor out_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - - std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; - std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; - std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - using InElementOp = ck::tensor_operation::element_wise::PassThrough; - using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; - using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - - const auto in_element_op = InElementOp{}; - const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{}; - - if(do_verification) - { - using ReferenceConvBwdDataInstance = - ck::tensor_operation::host::ReferenceConvBwdData; - - auto ref_conv = ReferenceConvBwdDataInstance{}; - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi_host_result, - wei_k_c_y_x, - out_n_k_ho_wo, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - in_element_op, - wei_element_op, - out_element_op); - - ref_invoker.Run(ref_argument); - } - - DeviceMem in_device_buf(sizeof(InDataType) * - in_n_c_hi_wi_device_result.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - out_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - - using DeviceConvBwdDataNoOpPtr = - ck::tensor_operation::device::DeviceConvBwdDataPtr; - - // add device Conv instances - std::vector conv_ptrs; - if constexpr(ck::is_same_v, float> && - ck::is_same_v, float> && - ck::is_same_v, float>) - { - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: - add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); - } - else if constexpr(ck::is_same_v, ck::half_t> && - ck::is_same_v, ck::half_t> && - ck::is_same_v, ck::half_t>) - { - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: - add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); - } - else if constexpr(ck::is_same_v, ck::bhalf_t> && - ck::is_same_v, ck::bhalf_t> && - ck::is_same_v, ck::bhalf_t>) - { - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: - add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); - } - else if constexpr(ck::is_same_v, int8_t> && - ck::is_same_v, int8_t> && - ck::is_same_v, int8_t>) - { - ck::tensor_operation::device::device_conv2d_bwd_data_instance:: - add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); - } - - if(conv_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device Conv instance found"); - } - - std::string best_conv_name; - float best_ave_time = 0; - float best_tflops = 0; - float best_gb_per_sec = 0; - - // profile device Conv instances - for(auto& conv_ptr : conv_ptrs) - { - auto argument_ptr = conv_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - in_element_op, - wei_element_op, - out_element_op); - - auto invoker_ptr = conv_ptr->MakeInvokerPointer(); - - if(conv_ptr->IsSupportedArgument(argument_ptr.get())) - { - std::string conv_name = conv_ptr->GetTypeString(); - - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); - - std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; - - std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + - sizeof(WeiDataType) * (K * C * Y * X) + - sizeof(OutDataType) * (N * K * Ho * Wo); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << conv_name << std::endl; - - if(tflops > best_tflops) - { - best_conv_name = conv_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - - if(do_verification) - { - in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); - - ck::utils::check_err(in_n_c_hi_wi_device_result.mData, - in_n_c_hi_wi_host_result.mData); - - if(do_log) - { - LogRangeAsType(std::cout << "in : ", out_n_k_ho_wo.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "wei: ", wei_k_c_y_x.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "out_host : ", in_n_c_hi_wi_host_result.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "out_device: ", in_n_c_hi_wi_device_result.mData, ",") - << std::endl; - } - } - } - } - - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; -} - -} // namespace profiler -} // namespace ck diff --git a/profiler/include/profile_conv_bwd_weight_impl.hpp b/profiler/include/profile_conv_bwd_weight_impl.hpp index 20fe0ef549..8e3a4074b0 100644 --- a/profiler/include/profile_conv_bwd_weight_impl.hpp +++ b/profiler/include/profile_conv_bwd_weight_impl.hpp @@ -1,4 +1,6 @@ #pragma once + +#include "stream_config.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -43,7 +45,7 @@ template MakeArgumentPointer( static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), @@ -214,7 +218,8 @@ bool profile_conv_bwd_weight_impl(int do_verification, { std::string conv_name = conv_ptr->GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; @@ -242,6 +247,7 @@ bool profile_conv_bwd_weight_impl(int do_verification, wei_device_buf.FromDevice(wei_k_c_y_x_device_result.mData.data()); float max_error = check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result); + if(max_error > 8) { pass = false; diff --git a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp index d0de7307d2..5ea35cd72f 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp @@ -42,7 +42,7 @@ template GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; diff --git a/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp index 9bdfa61283..f1c2fd300a 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp @@ -119,7 +119,7 @@ template GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; diff --git a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp index f34e52048e..eeb2b93e4e 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp @@ -41,7 +41,7 @@ template GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; diff --git a/profiler/include/profile_convnd_bwd_data_impl.hpp b/profiler/include/profile_convnd_bwd_data_impl.hpp index 4f9038a72b..291bf2abc0 100644 --- a/profiler/include/profile_convnd_bwd_data_impl.hpp +++ b/profiler/include/profile_convnd_bwd_data_impl.hpp @@ -1,7 +1,7 @@ #pragma once #include "config.hpp" #include "device.hpp" -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "tensor_layout.hpp" @@ -222,7 +222,7 @@ static bool check_out(const Tensor& ref, const Tensor& result) { float max_diff = 1e-6; - for(int i = 0; i < ref.mData.size(); ++i) + for(std::size_t i = 0; i < ref.mData.size(); ++i) { float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); if(max_diff < diff) @@ -236,16 +236,16 @@ template void show_data_nhwc_layout(Tensor& nhwc) { std::cout << "["; - for(int n = 0; n < nhwc.mDesc.GetLengths()[0]; n++) + for(int n = 0; n < ck::type_convert(nhwc.mDesc.GetLengths()[0]); n++) { std::cout << "["; - for(int hi = 0; hi < nhwc.mDesc.GetLengths()[2]; hi++) + for(int hi = 0; hi < ck::type_convert(nhwc.mDesc.GetLengths()[2]); hi++) { std::cout << "["; - for(int wi = 0; wi < nhwc.mDesc.GetLengths()[3]; wi++) + for(int wi = 0; wi < ck::type_convert(nhwc.mDesc.GetLengths()[3]); wi++) { std::cout << "["; - for(int c = 0; c < nhwc.mDesc.GetLengths()[1]; c++) + for(int c = 0; c < ck::type_convert(nhwc.mDesc.GetLengths()[1]); c++) { std::cout << static_cast(nhwc(n, c, hi, wi)) << " "; } @@ -269,7 +269,7 @@ template GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); diff --git a/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp b/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp new file mode 100644 index 0000000000..748c9ada80 --- /dev/null +++ b/profiler/include/profile_gemm_add_add_fastgelu_impl.hpp @@ -0,0 +1,288 @@ +#pragma once + +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "device_gemm_multiple_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMultipleDPtr< + 2, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddAddFastGelu>; + +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +int profile_gemm_add_add_fastgelu_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideE) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl; + std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = AddAddFastGelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + // add device GEMM instances + std::vector + device_op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && + is_same_v && + is_same_v) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + device_op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + device_op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + device_op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + device_op_ptrs); + } + } + + std::cout << "found " << device_op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n(HostTensorDescriptor( + std::vector{static_cast(M), static_cast(N)})); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace()); + DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + d1_m_n_device_buf.ToDevice(d1_m_n.mData.data()); + + std::string best_device_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& device_op_ptr : device_op_ptrs) + { + auto argument_ptr = device_op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d0_m_n_device_buf.GetDeviceBuffer(), + d1_m_n_device_buf.GetDeviceBuffer()}, + static_cast(e_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD0, StrideD1}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = device_op_ptr->MakeInvokerPointer(); + + std::string device_op_name = device_op_ptr->GetTypeString(); + + if(device_op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << device_op_name << std::endl; + + if(tflops > best_tflops) + { + best_device_op_name = device_op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && + ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData); + } + } + else + { + std::cout << device_op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_device_op_name << std::endl; + + return pass ? 0 : 1; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_bias_2d_impl.hpp b/profiler/include/profile_gemm_bias_2d_impl.hpp index 98e4ad76c9..8565f9637c 100644 --- a/profiler/include/profile_gemm_bias_2d_impl.hpp +++ b/profiler/include/profile_gemm_bias_2d_impl.hpp @@ -65,7 +65,7 @@ template GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; diff --git a/profiler/include/profile_gemm_bias_add_reduce_impl.hpp b/profiler/include/profile_gemm_bias_add_reduce_impl.hpp new file mode 100644 index 0000000000..5b792219c0 --- /dev/null +++ b/profiler/include/profile_gemm_bias_add_reduce_impl.hpp @@ -0,0 +1,386 @@ +#pragma once +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "reduction_operator.hpp" +#include "device_gemm_reduce.hpp" +#include "reference_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F32 = float; +using F16 = ck::half_t; +using DPtrsGlobal = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + +using DeviceGemmBiasAddReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmBiasAddReducePtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + DInElementOps, + DOutElementOps>; + +void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + std::vector&); + +void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_add_reduce_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int StrideC1) +{ + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); + }; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + + Tensor c_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor bias_n(f_host_tensor_descriptor1d(N, 1)); + Tensor c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor d0_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_host_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + Tensor c_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor d0_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + Tensor d1_m_device_result( + HostTensorDescriptor(std::vector({static_cast(M)}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; + std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; + + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + bias_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + c1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + bias_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + c1_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = PassThrough; + using C1ElementOp = PassThrough; + using D0ReduceOp = ck::reduce::Add; + using D1ReduceOp = ck::reduce::Add; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + using DxsInElementOps = ck::Tuple; + using DxsOutElementOps = ck::Tuple; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto c1_element_op = C1ElementOp{}; + const auto d0_reduce_op = D0ReduceOp{}; + const auto d1_reduce_op = D1ReduceOp{}; + + auto dxs_in_element_op = DxsInElementOps{}; + auto dxs_out_element_op = DxsOutElementOps{N, N}; + + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + using ReduceAccDataType = DDataType; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + for(int n = 0; n < N; ++n) + { + ReduceAccDataType acc = static_cast(c_m_n_host_result(m, n)) + + static_cast(bias_n(n)); + + ReduceAccDataType c1 = static_cast(c1_m_n(m, n)); + c_element_op(acc, acc); + c1_element_op(c1, c1); + acc += c1; + c_m_n_host_result(m, n) = static_cast(acc); + } + + for(int m = 0; m < M; ++m) + { + auto d0_acc = d0_reduce_op.GetIdentityValue(); + auto d1_acc = d1_reduce_op.GetIdentityValue(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType c_val = + ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType d0_val; + ReduceAccDataType d1_val; + + dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); + dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); + d0_reduce_op(d0_acc, d0_val); + d1_reduce_op(d1_acc, d1_val); + } + + dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc); + dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc); + d0_m_host_result(m) = ck::type_convert(d0_acc); + d1_m_host_result(m) = ck::type_convert(d1_acc); + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace()); + DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); + + auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer())); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + bias_device_buf.ToDevice(bias_n.mData.data()); + c1_device_buf.ToDevice(c1_m_n.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(bias_device_buf.GetDeviceBuffer()), + static_cast(c1_device_buf.GetDeviceBuffer()), + &dxs_global, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideC1, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + dxs_in_element_op, + dxs_out_element_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; + + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N + + sizeof(C1DataType) * M * N + sizeof(DDataType) * M + + sizeof(DDataType) * M; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + d0_device_buf.FromDevice(d0_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_m_device_result.mData.data()); + + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(d0_m_device_result.mData, d0_m_host_result.mData); + ck::utils::check_err(d1_m_device_result.mData, d1_m_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0_host: ", d0_m_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0_device: ", d0_m_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d1_host: ", d1_m_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d1_device: ", d1_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_bias_relu_add_impl.hpp b/profiler/include/profile_gemm_bias_relu_add_impl.hpp index 75ed78075b..6fec17c199 100644 --- a/profiler/include/profile_gemm_bias_relu_add_impl.hpp +++ b/profiler/include/profile_gemm_bias_relu_add_impl.hpp @@ -48,7 +48,7 @@ template GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; diff --git a/profiler/include/profile_gemm_bias_relu_impl.hpp b/profiler/include/profile_gemm_bias_relu_impl.hpp index 0735f3c31b..69010becc5 100644 --- a/profiler/include/profile_gemm_bias_relu_impl.hpp +++ b/profiler/include/profile_gemm_bias_relu_impl.hpp @@ -48,7 +48,7 @@ template GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 93262fe802..a3400f89b3 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -1,5 +1,7 @@ #pragma once #include +#include +#include #include "check_err.hpp" #include "config.hpp" @@ -42,14 +44,10 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( - std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( std::vector&); @@ -74,6 +72,21 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); + } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation @@ -85,13 +98,14 @@ namespace profiler { template void profile_gemm_impl(int do_verification, int init_method, bool do_log, - int nrepeat, + bool time_kernel, int M, int N, int K, @@ -125,7 +139,11 @@ void profile_gemm_impl(int do_verification, std::size_t num_thread = 1; switch(init_method) { - case 0: break; + // case 0: break; + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); @@ -174,6 +192,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs); } @@ -192,6 +213,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs); } @@ -210,6 +234,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(gemm_ptrs); } @@ -228,6 +255,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(gemm_ptrs); } @@ -250,6 +280,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); } @@ -268,6 +301,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); @@ -289,6 +325,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); } @@ -307,6 +346,9 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); } @@ -353,28 +395,40 @@ void profile_gemm_impl(int do_verification, is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && is_same::value) { ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemm_ptrs); + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemm_ptrs); } } @@ -416,7 +470,8 @@ void profile_gemm_impl(int do_verification, std::string gemm_name = gemm_ptr->GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = std::size_t(2) * M * N * K; @@ -457,8 +512,14 @@ void profile_gemm_impl(int do_verification, bf16_to_f32_(b_k_n, b_f32_k_n); bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result); - using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -490,6 +551,7 @@ void profile_gemm_impl(int do_verification, ck::tensor_operation::host::ReferenceGemm; @@ -522,12 +584,50 @@ void profile_gemm_impl(int do_verification, } else { - std::cout << "does not support this GEMM problem" << std::endl; + std::cout << gemm_ptr->GetTypeString() << " does not support this GEMM problem" + << std::endl; } } - std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " - << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " : " << best_ave_time + << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, " + << best_gemm_name << std::endl; } } // namespace profiler diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp index 6ef3e010b1..97c23defe0 100644 --- a/profiler/include/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -1,4 +1,5 @@ #pragma once +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -16,11 +17,21 @@ namespace tensor_operation { namespace device { namespace device_gemm_instance { +using F32 = float; +using F16 = ck::half_t; +using DPtrsGlobal = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryDivide; +using Identity = ck::tensor_operation::element_wise::PassThrough; +using Square = ck::tensor_operation::element_wise::UnarySquare; +using DInElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; + using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::UnarySquare>; + DInElementOps, + DOutElementOps>; void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances( std::vector&); @@ -52,7 +63,7 @@ template {-0.5, 0.5}, num_thread); } - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - using D0ReduceOp = ck::reduce::Add; - using D1ReduceOp = ck::reduce::Add; - using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare; + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using D0ReduceOp = ck::reduce::Add; + using D1ReduceOp = ck::reduce::Add; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; + using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; + using DxsInElementOps = ck::Tuple; + using DxsOutElementOps = ck::Tuple; - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{}; - const auto d0_reduce_op = D0ReduceOp{}; - const auto d1_reduce_op = D1ReduceOp{}; - const auto d1_element_op = D1ElementOp{}; + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto d0_reduce_op = D0ReduceOp{}; + const auto d1_reduce_op = D1ReduceOp{}; + + auto dxs_in_element_op = DxsInElementOps{}; + auto dxs_out_element_op = DxsOutElementOps{N, N}; if(do_verification) { - using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + using ReduceAccDataType = DDataType; auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); @@ -141,19 +165,24 @@ bool profile_gemm_reduce_impl(int do_verification, for(int m = 0; m < M; ++m) { - float d0_acc = d0_reduce_op.GetReductionZeroVal(); - float d1_acc = d1_reduce_op.GetReductionZeroVal(); + auto d0_acc = d0_reduce_op.GetIdentityValue(); + auto d1_acc = d1_reduce_op.GetIdentityValue(); for(int n = 0; n < N; ++n) { - float d0_val = ck::type_convert(c_m_n_host_result(m, n)); - float d1_val; + ReduceAccDataType c_val = + ck::type_convert(c_m_n_host_result(m, n)); + ReduceAccDataType d0_val; + ReduceAccDataType d1_val; - d1_element_op(d1_val, d0_val); + dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); + dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); d0_reduce_op(d0_acc, d0_val); d1_reduce_op(d1_acc, d1_val); } + dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc); + dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc); d0_m_host_result(m) = ck::type_convert(d0_acc); d1_m_host_result(m) = ck::type_convert(d1_acc); } @@ -165,6 +194,9 @@ bool profile_gemm_reduce_impl(int do_verification, DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); + auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer())); + a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); @@ -226,8 +258,7 @@ bool profile_gemm_reduce_impl(int do_verification, gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), static_cast(b_device_buf.GetDeviceBuffer()), static_cast(c_device_buf.GetDeviceBuffer()), - static_cast(d0_device_buf.GetDeviceBuffer()), - static_cast(d1_device_buf.GetDeviceBuffer()), + &dxs_global, M, N, K, @@ -237,42 +268,25 @@ bool profile_gemm_reduce_impl(int do_verification, a_element_op, b_element_op, c_element_op, - d1_element_op); + dxs_in_element_op, + dxs_out_element_op); auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - // warm up - invoker_ptr->Run(argument_ptr.get()); + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); - // timing - float total_time = 0; - - for(int i = 0; i < nrepeat; ++i) - { - // init DO, D1 to 0 - d0_device_buf.SetZero(); - d1_device_buf.SetZero(); - - KernelTimer timer; - - timer.Start(); - - invoker_ptr->Run(argument_ptr.get()); - - timer.End(); - - total_time += timer.GetElapsedTime(); - } - - float ave_time = total_time / nrepeat; + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::string gemm_name = gemm_ptr->GetTypeString(); std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N + sizeof(CDataType) * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -296,13 +310,9 @@ bool profile_gemm_reduce_impl(int do_verification, d0_device_buf.FromDevice(d0_m_device_result.mData.data()); d1_device_buf.FromDevice(d1_m_device_result.mData.data()); - float c_error = check_error(c_m_n_host_result, c_m_n_device_result); - float d0_error = check_error(d0_m_host_result, d0_m_device_result); - float d1_error = check_error(d1_m_host_result, d1_m_device_result); - - pass = pass && (c_error < 1E-6); - pass = pass && (d0_error < 1E-6); - pass = pass && (d1_error < 1E-6); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(d0_m_device_result.mData, d0_m_host_result.mData); + ck::utils::check_err(d1_m_device_result.mData, d1_m_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_grouped_gemm_impl.hpp b/profiler/include/profile_grouped_gemm_impl.hpp index cced480c36..8806e8ff43 100644 --- a/profiler/include/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profile_grouped_gemm_impl.hpp @@ -43,19 +43,20 @@ namespace profiler { template void profile_grouped_gemm_impl(int do_verification, int init_method, bool do_log, - int nrepeat, - std::vector Ms, - std::vector Ns, - std::vector Ks, - std::vector StrideAs, - std::vector StrideBs, - std::vector StrideCs) + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs) { auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -71,7 +72,7 @@ void profile_grouped_gemm_impl(int do_verification, } }; - int group_count = Ms.size(); + std::size_t group_count = Ms.size(); if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && group_count == StrideBs.size() && group_count == StrideCs.size())) @@ -83,7 +84,7 @@ void profile_grouped_gemm_impl(int do_verification, std::vector> b_k_n; std::vector> c_m_n_device_results; - for(int i = 0; i < Ms.size(); i++) + for(std::size_t i = 0; i < group_count; i++) { a_m_k.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); @@ -144,7 +145,7 @@ void profile_grouped_gemm_impl(int do_verification, gemm_shapes.reserve(group_count); - for(int i = 0; i < group_count; i++) + for(std::size_t i = 0; i < group_count; i++) { a_device_buf.emplace_back( std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace())); @@ -231,10 +232,11 @@ void profile_grouped_gemm_impl(int do_verification, { std::string gemm_name = gemm_ptr->GetTypeString(); - float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t flop = 0, num_btype = 0; - for(int i = 0; i < gemm_shapes.size(); i++) + for(std::size_t i = 0; i < gemm_shapes.size(); i++) { flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; @@ -258,7 +260,7 @@ void profile_grouped_gemm_impl(int do_verification, if(do_verification) { - for(int i = 0; i < gemm_shapes.size(); i++) + for(std::size_t i = 0; i < gemm_shapes.size(); i++) { c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); @@ -270,6 +272,7 @@ void profile_grouped_gemm_impl(int do_verification, ck::tensor_operation::host::ReferenceGemm; diff --git a/profiler/include/profile_reduce_impl.hpp b/profiler/include/profile_reduce_impl.hpp index 678134f60b..5e192aa1bc 100644 --- a/profiler/include/profile_reduce_impl.hpp +++ b/profiler/include/profile_reduce_impl.hpp @@ -5,74 +5,77 @@ #include "device_reduce_instance.hpp" #include "reduction_enums.hpp" #include "host_reduction.hpp" +#include "host_common_util.hpp" +#include "host_tensor_generator.hpp" namespace ck { namespace tensor_operation { namespace device { namespace device_reduce_instance { -template +template struct ReduceDescription { static constexpr int Rank_ = Rank; static constexpr int NumReduceDim_ = NumReduceDim; static constexpr int ReduceOpId_ = ReduceOpId; - static constexpr int NanOpt_ = NanOpt; - static constexpr int IndicesOpt_ = IndicesOpt; + static constexpr int PropagateNan_ = PropagateNan; + static constexpr int UseIndex_ = UseIndex; }; -using reduce_description_instances = std::tuple, // for ADD - ReduceDescription<4, 4, 0, 0, 0>, - ReduceDescription<4, 1, 0, 0, 0>, - ReduceDescription<2, 1, 0, 0, 0>, +using reduce_description_instances = + std::tuple, // for ADD + ReduceDescription<4, 4, 0, false, false>, + ReduceDescription<4, 1, 0, false, false>, + ReduceDescription<2, 1, 0, false, false>, - ReduceDescription<4, 3, 5, 0, 0>, // for AVG - ReduceDescription<4, 4, 5, 0, 0>, - ReduceDescription<4, 1, 5, 0, 0>, - ReduceDescription<2, 1, 5, 0, 0>, + ReduceDescription<4, 3, 5, false, false>, // for AVG + ReduceDescription<4, 4, 5, false, false>, + ReduceDescription<4, 1, 5, false, false>, + ReduceDescription<2, 1, 5, false, false>, - ReduceDescription<4, 3, 7, 0, 0>, // for NORM2 - ReduceDescription<4, 4, 7, 0, 0>, - ReduceDescription<4, 1, 7, 0, 0>, - ReduceDescription<2, 1, 7, 0, 0>, + ReduceDescription<4, 3, 7, false, false>, // for NORM2 + ReduceDescription<4, 4, 7, false, false>, + ReduceDescription<4, 1, 7, false, false>, + ReduceDescription<2, 1, 7, false, false>, - ReduceDescription<4, 3, 2, 0, 0>, // for MIN - ReduceDescription<4, 4, 2, 0, 0>, - ReduceDescription<4, 1, 2, 0, 0>, - ReduceDescription<2, 1, 2, 0, 0>, - ReduceDescription<4, 3, 3, 0, 0>, // for MAX - ReduceDescription<4, 4, 3, 0, 0>, - ReduceDescription<4, 1, 3, 0, 0>, - ReduceDescription<2, 1, 3, 0, 0>, - ReduceDescription<4, 3, 4, 0, 0>, // for AMAX - ReduceDescription<4, 4, 4, 0, 0>, - ReduceDescription<4, 1, 4, 0, 0>, - ReduceDescription<2, 1, 4, 0, 0>, + ReduceDescription<4, 3, 2, false, false>, // for MIN + ReduceDescription<4, 4, 2, false, false>, + ReduceDescription<4, 1, 2, false, false>, + ReduceDescription<2, 1, 2, false, false>, + ReduceDescription<4, 3, 3, false, false>, // for MAX + ReduceDescription<4, 4, 3, false, false>, + ReduceDescription<4, 1, 3, false, false>, + ReduceDescription<2, 1, 3, false, false>, + ReduceDescription<4, 3, 4, false, false>, // for AMAX + ReduceDescription<4, 4, 4, false, false>, + ReduceDescription<4, 1, 4, false, false>, + ReduceDescription<2, 1, 4, false, false>, - ReduceDescription<4, 3, 2, 0, 1>, // for MIN - ReduceDescription<4, 4, 2, 0, 1>, - ReduceDescription<4, 1, 2, 0, 1>, - ReduceDescription<2, 1, 2, 0, 1>, - ReduceDescription<4, 3, 3, 0, 1>, // for MAX - ReduceDescription<4, 4, 3, 0, 1>, - ReduceDescription<4, 1, 3, 0, 1>, - ReduceDescription<2, 1, 3, 0, 1>, - ReduceDescription<4, 3, 4, 0, 1>, // for AMAX - ReduceDescription<4, 4, 4, 0, 1>, - ReduceDescription<4, 1, 4, 0, 1>, - ReduceDescription<2, 1, 4, 0, 1>>; + ReduceDescription<4, 3, 2, false, true>, // for MIN + ReduceDescription<4, 4, 2, false, true>, + ReduceDescription<4, 1, 2, false, true>, + ReduceDescription<2, 1, 2, false, true>, + ReduceDescription<4, 3, 3, false, true>, // for MAX + ReduceDescription<4, 4, 3, false, true>, + ReduceDescription<4, 1, 3, false, true>, + ReduceDescription<2, 1, 3, false, true>, + ReduceDescription<4, 3, 4, false, true>, // for AMAX + ReduceDescription<4, 4, 4, false, true>, + ReduceDescription<4, 1, 4, false, true>, + ReduceDescription<2, 1, 4, false, true>>; template bool description_match(const DescriptionType& description, int Rank, const std::vector& reduceDims, ReduceTensorOp ReduceOpId, - NanPropagation NanOpt, - ReduceTensorIndices IndicesOpt) + bool PropagateNan, + bool UseIndex) { if(description.Rank_ != Rank || description.ReduceOpId_ != static_cast(ReduceOpId) || - description.NanOpt_ != static_cast(NanOpt) || - description.IndicesOpt_ != static_cast(IndicesOpt)) + description.PropagateNan_ != static_cast(PropagateNan) || + description.UseIndex_ != static_cast(UseIndex)) return (false); if(DescriptionType::NumReduceDim_ != reduceDims.size()) @@ -116,48 +119,18 @@ static inline std::vector get_invariant_dims(const std::vector& reduce return invariantDims; }; -template -static void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) -{ - std::ofstream outFile(fileName, std::ios::binary); - if(outFile) - { - outFile.write(reinterpret_cast(data), dataNumItems * sizeof(T)); - outFile.close(); - std::cout << "Write output to file " << fileName << std::endl; - } - else - { - std::cout << "Could not open file " << fileName << " for writing" << std::endl; - } -}; - -// map the data type used by the GPU kernels to the corresponding type used by the host codes -template -struct type_mapping -{ - using OutType = InType; -}; - -template <> -struct type_mapping -{ - using OutType = half_float::half; -}; - template -void profile_reduce_impl_impl(bool do_verification, + bool PropagateNan, + bool UseIndex> +bool profile_reduce_impl_impl(bool do_verification, int init_method, - bool do_log, bool do_dumpout, - int nrepeat, + bool time_kernel, const std::vector& inLengths, const std::vector& reduceDims, float alpha, @@ -165,16 +138,13 @@ void profile_reduce_impl_impl(bool do_verification, { using namespace ck::tensor_operation::device; using namespace ck::tensor_operation::device::device_reduce_instance; - using namespace ck::host_reduce; + using ck::host_common::dumpBufferToFile; constexpr bool op_support_indices = (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || ReduceOpId == ReduceTensorOp::AMAX); - constexpr bool NeedIndices = - (op_support_indices && (IndicesOpt != ReduceTensorIndices::NO_INDICES)); - - constexpr bool PropagateNan = (NanOpt == NanPropagation::PROPAGATE_NAN); + constexpr bool OutputIndex = (op_support_indices && UseIndex); constexpr bool out_support_atomic_add = std::is_same::value; constexpr bool op_support_atomic_add = @@ -195,8 +165,7 @@ void profile_reduce_impl_impl(bool do_verification, (op_support_indices && !std::is_same::value); // 1) The indices can only be used when the reduction operation is indexable - constexpr bool invalid_reduce_3 = - (!op_support_indices && IndicesOpt != ReduceTensorIndices::NO_INDICES); + constexpr bool invalid_reduce_3 = (!op_support_indices && UseIndex); // 1) If InDataType is int8_t, must use int8_t as AccDataType for indexable reduction operations // 2) If InDataType is int8_t, must use int32_t as AccDataType for non-indexable reduction @@ -219,6 +188,8 @@ void profile_reduce_impl_impl(bool do_verification, constexpr bool invalid_reduce = (invalid_reduce_1 || invalid_reduce_2 || invalid_reduce_3 || invalid_reduce_4 || invalid_reduce_5 || invalid_reduce_6); + bool pass = true; + if constexpr(!invalid_reduce) { Tensor in(inLengths); @@ -282,42 +253,31 @@ void profile_reduce_impl_impl(bool do_verification, if(beta != 0.0f) out_dev.ToDevice(out.mData.data()); - size_t indicesSizeInBytes = NeedIndices ? out.mDesc.GetElementSize() * sizeof(int) : 0; + size_t indicesSizeInBytes = OutputIndex ? out.mDesc.GetElementSize() * sizeof(int) : 0; DeviceMem out_indices_dev(indicesSizeInBytes); float best_avg_time = 0; float best_gb_per_sec = 0; - using InElementwiseOperation_0 = - typename reduce_unary_operator:: - InElementwiseOperation; - using AccElementwiseOperation_0 = - typename reduce_unary_operator:: - AccElementwiseOperation; - using InElementwiseOperation_1 = - typename reduce_unary_operator:: - InElementwiseOperation; - using AccElementwiseOperation_1 = - typename reduce_unary_operator:: - AccElementwiseOperation; - using InElementwiseOperation_2 = - typename reduce_unary_operator:: - InElementwiseOperation; - using AccElementwiseOperation_2 = - typename reduce_unary_operator:: - AccElementwiseOperation; + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + using ReduceOperation = typename reduce_binary_operator::opType; + + InElementwiseOperation in_elementwise_op; + AccElementwiseOperation acc_elementwise_op; + + std::tie(in_elementwise_op, acc_elementwise_op) = + reduce_unary_operator::GetElementwiseOperator( + static_cast(reduce_total_length)); using DeviceReduceInstPtr0 = - DeviceReducePtr; - using DeviceReduceInstPtr1 = - DeviceReducePtr; - using DeviceReduceInstPtr2 = - DeviceReducePtr; + DeviceReducePtr; std::vector reduce0_ptrs; - std::vector reduce1_ptrs; - std::vector reduce2_ptrs; add_device_reduce_instance_threadwise(reduce0_ptrs); + PropagateNan, + UseIndex>(reduce0_ptrs); add_device_reduce_instance_blockwise(reduce0_ptrs); + PropagateNan, + UseIndex>(reduce0_ptrs); if constexpr(use_atomic_add) { @@ -345,35 +305,11 @@ void profile_reduce_impl_impl(bool do_verification, Rank, NumReduceDim, ReduceOpId, - NanOpt, - IndicesOpt>(reduce0_ptrs); + PropagateNan, + UseIndex>(reduce0_ptrs); } - else - { - add_device_reduce_instance_multiblock_partial_reduce(reduce1_ptrs); - }; - // used for secondary reduction - if constexpr(!use_atomic_add) - { - add_device_reduce_instance_blockwise_second_call(reduce2_ptrs); - }; - - if(reduce0_ptrs.empty() && reduce1_ptrs.empty()) + if(reduce0_ptrs.empty()) { throw std::runtime_error("Wrong! No device REDUCE instance found"); }; @@ -383,32 +319,36 @@ void profile_reduce_impl_impl(bool do_verification, ReductionHost + OutputIndex> hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); - hostReduce.Run( - alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); + hostReduce.Run(alpha, + in.mData.data(), + beta, + out_ref.mData.data(), + out_indices_ref.mData.data(), + in_elementwise_op, + acc_elementwise_op); }; - const auto i_inLengths = to_int_vector(inLengths); - const auto i_inStrides = to_int_vector(inStrides); - const auto i_outLengths = to_int_vector(outLengths); - const auto i_outStrides = to_int_vector(outStrides); + std::vector i_inLengths; + std::vector i_inStrides; + std::vector i_outLengths; + std::vector i_outStrides; + + i_inLengths.assign(inLengths.begin(), inLengths.end()); + i_inStrides.assign(inStrides.begin(), inStrides.end()); + i_outLengths.assign(outLengths.begin(), outLengths.end()); + i_outStrides.assign(outStrides.begin(), outStrides.end()); for(auto& reduce_ptr : reduce0_ptrs) { - auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths, reduceDims); - - DeviceMem ws_dev(wsSizeInBytes); - - InElementwiseOperation_0 in_elementwise_op_0(static_cast(reduce_total_length)); - AccElementwiseOperation_0 acc_elementwise_op_0( - static_cast(reduce_total_length)); - auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, i_inStrides, i_outLengths, @@ -417,11 +357,11 @@ void profile_reduce_impl_impl(bool do_verification, alpha, beta, in_dev.GetDeviceBuffer(), + nullptr, out_dev.GetDeviceBuffer(), out_indices_dev.GetDeviceBuffer(), - ws_dev.GetDeviceBuffer(), - in_elementwise_op_0, - acc_elementwise_op_0); + in_elementwise_op, + acc_elementwise_op); if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) continue; @@ -430,7 +370,8 @@ void profile_reduce_impl_impl(bool do_verification, auto invoker_ptr = reduce_ptr->MakeInvokerPointer(); - float avg_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + float avg_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); std::size_t num_bytes = invariant_total_length * reduce_total_length * sizeof(InDataType) + @@ -438,8 +379,9 @@ void profile_reduce_impl_impl(bool do_verification, float gb_per_sec = num_bytes / 1.E6 / avg_time; - std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " << reduce_name - << std::endl; + if(time_kernel) + std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s, " + << reduce_name << std::endl; if(gb_per_sec > best_gb_per_sec) { @@ -449,22 +391,24 @@ void profile_reduce_impl_impl(bool do_verification, if(do_verification) { - out_dev.FromDevice(out.mData.data()); - ck::utils::check_err(out.mData, out_ref.mData); + bool single_pass; - if(NeedIndices) + out_dev.FromDevice(out.mData.data()); + single_pass = ck::utils::check_err(out.mData, out_ref.mData); + + if(OutputIndex) { out_indices_dev.FromDevice(out_indices.mData.data()); - ck::utils::check_err(out_indices.mData, out_indices_ref.mData); - ; + single_pass = single_pass && + ck::utils::check_err(out_indices.mData, out_indices_ref.mData); }; - if(do_log) + if(!single_pass) { - LogRangeAsType(std::cout << "out_host : ", out_ref.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "out_device: ", out.mData, ",") << std::endl; - }; + std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl; + } + + pass = pass && single_pass; }; if(do_dumpout) @@ -473,7 +417,7 @@ void profile_reduce_impl_impl(bool do_verification, dumpBufferToFile("dump_out.bin", out.mData.data(), out.mDesc.GetElementSize()); dumpBufferToFile( "dump_out_host.bin", out_ref.mData.data(), out_ref.mDesc.GetElementSize()); - if(NeedIndices) + if(OutputIndex) { dumpBufferToFile("dump_indices.bin", out_indices.mData.data(), @@ -485,156 +429,34 @@ void profile_reduce_impl_impl(bool do_verification, }; }; - for(auto& reduce_ptr : reduce1_ptrs) - { - auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths, reduceDims); - - DeviceMem ws_dev(wsSizeInBytes); - - InElementwiseOperation_1 in_elementwise_op_1(static_cast(reduce_total_length)); - AccElementwiseOperation_1 acc_elementwise_op_1( - static_cast(reduce_total_length)); - - auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, - i_inStrides, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - in_dev.GetDeviceBuffer(), - out_dev.GetDeviceBuffer(), - out_indices_dev.GetDeviceBuffer(), - ws_dev.GetDeviceBuffer(), - in_elementwise_op_1, - acc_elementwise_op_1); - - if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) - continue; - - std::string reduce_name = reduce_ptr->GetTypeString(); - - auto invoker_ptr = reduce_ptr->MakeInvokerPointer(); - - float avg_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); - - std::size_t num_bytes = - invariant_total_length * reduce_total_length * sizeof(InDataType) + - invariant_total_length * sizeof(OutDataType); - - std::vector inLengths2 = reduce_ptr->GetWorkspace2dLengths(argument_ptr.get()); - std::vector inStrides2{inLengths2[1], 1}; - - for(auto& reduce2_ptr : reduce2_ptrs) - { - InElementwiseOperation_2 in_elementwise_op_2( - static_cast(reduce_total_length)); - AccElementwiseOperation_2 acc_elementwise_op_2( - static_cast(reduce_total_length)); - - auto argument2_ptr = - reduce2_ptr->MakeArgumentPointer(inLengths2, - inStrides2, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - ws_dev.GetDeviceBuffer(), - out_dev.GetDeviceBuffer(), - out_indices_dev.GetDeviceBuffer(), - ws_dev.GetDeviceBuffer(), - in_elementwise_op_2, - acc_elementwise_op_2); - - if(!reduce2_ptr->IsSupportedArgument(argument2_ptr.get())) - continue; - - std::string reduce2_name = reduce2_ptr->GetTypeString(); - - auto invoker2_ptr = reduce2_ptr->MakeInvokerPointer(); - - float avg_time_2 = invoker2_ptr->Run(argument2_ptr.get(), nrepeat); - - std::size_t num_bytes_2 = - static_cast(inLengths2[0]) * inLengths2[1] * sizeof(AccDataType); - - float gb_per_sec = (num_bytes + num_bytes_2) / 1.E6 / (avg_time + avg_time_2); - - std::cout << "Perf: " << (avg_time + avg_time_2) << " ms, " << gb_per_sec - << " GB/s, " << reduce_name << " => " << reduce2_name << std::endl; - - if(gb_per_sec > best_gb_per_sec) - { - best_avg_time = avg_time + avg_time_2; - best_gb_per_sec = gb_per_sec; - } - - if(do_verification) - { - out_dev.FromDevice(out.mData.data()); - ck::utils::check_err(out.mData, out_ref.mData); - - if(NeedIndices) - { - out_indices_dev.FromDevice(out_indices.mData.data()); - ck::utils::check_err(out_indices.mData, out_indices_ref.mData); - ; - }; - - if(do_log) - { - LogRangeAsType(std::cout << "out_host : ", out_ref.mData, ",") - << std::endl; - LogRangeAsType(std::cout << "out_device: ", out.mData, ",") - << std::endl; - } - } - - if(do_dumpout) - { - dumpBufferToFile("dump_in.bin", in.mData.data(), in.mDesc.GetElementSize()); - dumpBufferToFile("dump_out.bin", out.mData.data(), out.mDesc.GetElementSize()); - dumpBufferToFile( - "dump_out_host.bin", out_ref.mData.data(), out_ref.mDesc.GetElementSize()); - if(NeedIndices) - { - dumpBufferToFile("dump_indices.bin", - out_indices.mData.data(), - out_indices.mDesc.GetElementSize()); - dumpBufferToFile("dump_indices_host.bin", - out_indices_ref.mData.data(), - out_indices_ref.mDesc.GetElementSize()); - }; - }; - }; - }; - - std::cout << "Best Perf: " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s" - << std::endl; + if(time_kernel) + std::cout << "Best Perf: " << best_avg_time << " ms, " << best_gb_per_sec << " GB/s" + << std::endl; } else { std::cout << "The requested reduction operation is not supported, please check !!!" << std::endl; }; + + return pass; }; template -void profile_reduce_impl(bool do_verification, +bool profile_reduce_impl(bool do_verification, int init_method, - bool do_log, bool do_dumpout, - int nrepeat, + bool time_kernel, const std::vector& inLengths, const std::vector& reduceDims, ReduceTensorOp ReduceOpId, - NanPropagation NanOpt, - ReduceTensorIndices IndicesOpt, + bool PropagateNan, + bool UseIndex, float alpha, float beta) { bool matched = false; + bool pass = true; using tuple_of_description_instances = tensor_operation::device::device_reduce_instance::reduce_description_instances; @@ -648,29 +470,30 @@ void profile_reduce_impl(bool do_verification, using descType = remove_cvref_t(tuple_object))>; if(!description_match( - descType{}, inLengths.size(), reduceDims, ReduceOpId, NanOpt, IndicesOpt)) + descType{}, inLengths.size(), reduceDims, ReduceOpId, PropagateNan, UseIndex)) return; - profile_reduce_impl_impl(descType::ReduceOpId_), - static_cast(descType::NanOpt_), - static_cast(descType::IndicesOpt_)>( - do_verification, - init_method, - do_log, - do_dumpout, - nrepeat, - inLengths, - reduceDims, - alpha, - beta); + pass = pass && + profile_reduce_impl_impl(descType::ReduceOpId_), + static_cast(descType::PropagateNan_), + static_cast(descType::UseIndex_)>(do_verification, + init_method, + do_dumpout, + time_kernel, + inLengths, + reduceDims, + alpha, + beta); matched = true; }); + + return pass; }; } // namespace profiler diff --git a/profiler/src/profile_batched_gemm.cpp b/profiler/src/profile_batched_gemm.cpp index 2a806b0818..fbdc07c3da 100644 --- a/profiler/src/profile_batched_gemm.cpp +++ b/profiler/src/profile_batched_gemm.cpp @@ -48,8 +48,8 @@ int profile_batched_gemm(int argc, char* argv[]) printf(" 3: A[g, k, m] * B[g, n, k] = C[g, m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg7: run kernel # of times (>1)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); exit(1); } @@ -59,7 +59,7 @@ int profile_batched_gemm(int argc, char* argv[]) const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); - const int nrepeat = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[7]); const int M = std::stoi(argv[8]); const int N = std::stoi(argv[9]); @@ -82,7 +82,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -102,7 +102,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -122,7 +122,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -142,7 +142,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -162,7 +162,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -182,7 +182,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -202,7 +202,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -222,7 +222,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -242,7 +242,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -262,7 +262,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -282,7 +282,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -302,7 +302,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -322,7 +322,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -342,7 +342,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -362,7 +362,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -382,7 +382,7 @@ int profile_batched_gemm(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -396,5 +396,5 @@ int profile_batched_gemm(int argc, char* argv[]) throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_batched_gemm_reduce.cpp b/profiler/src/profile_batched_gemm_reduce.cpp index 38c3f52193..594fc6bedb 100644 --- a/profiler/src/profile_batched_gemm_reduce.cpp +++ b/profiler/src/profile_batched_gemm_reduce.cpp @@ -33,8 +33,8 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg7: run kernel # of times (>1)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); printf("arg15: split k into mulitiple batch\n"); exit(1); @@ -45,7 +45,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); - const int nrepeat = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[7]); const int M = std::stoi(argv[8]); const int N = std::stoi(argv[9]); @@ -69,7 +69,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -91,7 +91,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -113,7 +113,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -135,7 +135,7 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -149,5 +149,5 @@ int profile_batched_gemm_reduce(int argc, char* argv[]) throw std::runtime_error("wrong! this data_type & layout is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_conv_bwd_data.cpp b/profiler/src/profile_conv_bwd_data.cpp deleted file mode 100644 index 2861af3d10..0000000000 --- a/profiler/src/profile_conv_bwd_data.cpp +++ /dev/null @@ -1,195 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "profile_conv_bwd_data_impl.hpp" - -enum struct ConvDataType -{ - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_BF16_BF16, // 2 - INT8_INT8_INT8, // 3 -}; - -enum struct ConvInputLayout -{ - NCHW, // 0 - NHWC, // 1 -}; - -enum struct ConvWeightLayout -{ - KCYX, // 0 - KYXC, // 1 -}; - -enum struct ConvOutputLayout -{ - NKHW, // 0 - NHWK, // 1 -}; - -int profile_conv_bwd_data(int argc, char* argv[]) -{ - if(argc != 25) - { - printf("arg1: tensor operation (conv_bwd: BackwardConvolution)\n"); - printf("arg2: data type (0: fp32; 1: fp16)\n"); - printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); - printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); - printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); - printf("arg6: verification (0: no; 1: yes)\n"); - printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg9: run kernel # of times (>1)\n"); - printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(1); - } - - const auto data_type = static_cast(std::stoi(argv[2])); - const auto in_layout = static_cast(std::stoi(argv[3])); - const auto wei_layout = static_cast(std::stoi(argv[4])); - const auto out_layout = static_cast(std::stoi(argv[5])); - const bool do_verification = std::stoi(argv[6]); - const int init_method = std::stoi(argv[7]); - const bool do_log = std::stoi(argv[8]); - const int nrepeat = std::stoi(argv[9]); - - const ck::index_t N = std::stoi(argv[10]); - const ck::index_t K = std::stoi(argv[11]); - const ck::index_t C = std::stoi(argv[12]); - const ck::index_t Y = std::stoi(argv[13]); - const ck::index_t X = std::stoi(argv[14]); - const ck::index_t Hi = std::stoi(argv[15]); - const ck::index_t Wi = std::stoi(argv[16]); - - const ck::index_t conv_stride_h = std::stoi(argv[17]); - const ck::index_t conv_stride_w = std::stoi(argv[18]); - const ck::index_t conv_dilation_h = std::stoi(argv[19]); - const ck::index_t conv_dilation_w = std::stoi(argv[20]); - const ck::index_t in_left_pad_h = std::stoi(argv[21]); - const ck::index_t in_left_pad_w = std::stoi(argv[22]); - const ck::index_t in_right_pad_h = std::stoi(argv[23]); - const ck::index_t in_right_pad_w = std::stoi(argv[24]); - - const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; - const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; - - const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - - if(data_type == ConvDataType::F32_F32_F32 && in_layout == ConvInputLayout::NHWC && - wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) - { - ck::profiler::profile_conv_bwd_data_impl<2, - float, - float, - float, - float, - ck::tensor_layout::convolution::NHWC, - ck::tensor_layout::convolution::KYXC, - ck::tensor_layout::convolution::NHWK>( - do_verification, - init_method, - do_log, - nrepeat, - N, - K, - C, - std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - std::vector{conv_stride_h, conv_stride_w}, - std::vector{conv_dilation_h, conv_dilation_w}, - std::vector{in_left_pad_h, in_left_pad_w}, - std::vector{in_right_pad_h, in_right_pad_w}); - } - else if(data_type == ConvDataType::F16_F16_F16 && in_layout == ConvInputLayout::NHWC && - wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) - { - ck::profiler::profile_conv_bwd_data_impl<2, - ck::half_t, - ck::half_t, - ck::half_t, - float, - ck::tensor_layout::convolution::NHWC, - ck::tensor_layout::convolution::KYXC, - ck::tensor_layout::convolution::NHWK>( - do_verification, - init_method, - do_log, - nrepeat, - N, - K, - C, - std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - std::vector{conv_stride_h, conv_stride_w}, - std::vector{conv_dilation_h, conv_dilation_w}, - std::vector{in_left_pad_h, in_left_pad_w}, - std::vector{in_right_pad_h, in_right_pad_w}); - } - else if(data_type == ConvDataType::BF16_BF16_BF16 && in_layout == ConvInputLayout::NHWC && - wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) - { - ck::profiler::profile_conv_bwd_data_impl<2, - uint16_t, - uint16_t, - uint16_t, - float, - ck::tensor_layout::convolution::NHWC, - ck::tensor_layout::convolution::KYXC, - ck::tensor_layout::convolution::NHWK>( - do_verification, - init_method, - do_log, - nrepeat, - N, - K, - C, - std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - std::vector{conv_stride_h, conv_stride_w}, - std::vector{conv_dilation_h, conv_dilation_w}, - std::vector{in_left_pad_h, in_left_pad_w}, - std::vector{in_right_pad_h, in_right_pad_w}); - } - else if(data_type == ConvDataType::INT8_INT8_INT8 && in_layout == ConvInputLayout::NHWC && - wei_layout == ConvWeightLayout::KYXC && out_layout == ConvOutputLayout::NHWK) - { - ck::profiler::profile_conv_bwd_data_impl<2, - int8_t, - int8_t, - int8_t, - int32_t, - ck::tensor_layout::convolution::NHWC, - ck::tensor_layout::convolution::KYXC, - ck::tensor_layout::convolution::NHWK>( - do_verification, - init_method, - do_log, - nrepeat, - N, - K, - C, - std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - std::vector{conv_stride_h, conv_stride_w}, - std::vector{conv_dilation_h, conv_dilation_w}, - std::vector{in_left_pad_h, in_left_pad_w}, - std::vector{in_right_pad_h, in_right_pad_w}); - } - else - { - throw std::runtime_error("wrong! this Conv data_type & layout is not implemented"); - } - - return 1; -} diff --git a/profiler/src/profile_conv_bwd_weight.cpp b/profiler/src/profile_conv_bwd_weight.cpp index 309cc8ea2c..80413322b3 100644 --- a/profiler/src/profile_conv_bwd_weight.cpp +++ b/profiler/src/profile_conv_bwd_weight.cpp @@ -58,7 +58,7 @@ int profile_conv_bwd_weight(int argc, char* argv[]) const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); - const int nrepeat = std::stoi(argv[9]); + const bool time_kernel = std::stoi(argv[9]); const ck::index_t N = std::stoi(argv[10]); const ck::index_t K = std::stoi(argv[11]); @@ -98,7 +98,7 @@ int profile_conv_bwd_weight(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, N, K, C, @@ -124,7 +124,7 @@ int profile_conv_bwd_weight(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, N, K, C, @@ -142,5 +142,5 @@ int profile_conv_bwd_weight(int argc, char* argv[]) throw std::runtime_error("wrong! this Conv data_type & layout is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_conv_fwd_bias_relu.cpp b/profiler/src/profile_conv_fwd_bias_relu.cpp index 1c447b483e..ca7dc1935a 100644 --- a/profiler/src/profile_conv_fwd_bias_relu.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu.cpp @@ -42,7 +42,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[]) printf("arg6: verification (0: no; 1: yes)\n"); printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg9: run kernel # of times (>1)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); exit(1); @@ -55,7 +55,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[]) const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); - const int nrepeat = std::stoi(argv[9]); + const bool time_kernel = std::stoi(argv[9]); const ck::index_t N = std::stoi(argv[10]); const ck::index_t K = std::stoi(argv[11]); @@ -93,7 +93,7 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, N, K, C, @@ -110,5 +110,5 @@ int profile_conv_fwd_bias_relu(int argc, char* argv[]) throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_conv_fwd_bias_relu_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_add.cpp index 522487c77b..5d75f5a294 100644 --- a/profiler/src/profile_conv_fwd_bias_relu_add.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu_add.cpp @@ -43,7 +43,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) printf("arg6: verification (0: no; 1: yes)\n"); printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg9: run kernel # of times (>1)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); exit(1); @@ -56,7 +56,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); - const int nrepeat = std::stoi(argv[9]); + const bool time_kernel = std::stoi(argv[9]); const ck::index_t N = std::stoi(argv[10]); const ck::index_t K = std::stoi(argv[11]); @@ -94,7 +94,7 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, N, K, C, @@ -111,5 +111,5 @@ int profile_conv_fwd_bias_relu_add(int argc, char* argv[]) throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp index 833f2851db..96d3b10ddf 100644 --- a/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp +++ b/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp @@ -43,7 +43,7 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[]) printf("arg6: verification (0: no; 1: yes)\n"); printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg9: run kernel # of times (>1)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); exit(1); @@ -56,7 +56,7 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[]) const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); - const int nrepeat = std::stoi(argv[9]); + const bool time_kernel = std::stoi(argv[9]); const ck::index_t N = std::stoi(argv[10]); const ck::index_t K = std::stoi(argv[11]); @@ -95,7 +95,7 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, N, K, C, @@ -112,5 +112,5 @@ int profile_conv_fwd_bias_relu_atomic_add(int argc, char* argv[]) throw std::runtime_error("wrong! data_type & layout for this operator is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_convnd_bwd_data.cpp b/profiler/src/profile_convnd_bwd_data.cpp index 893fb8c791..5d0e6a34c7 100644 --- a/profiler/src/profile_convnd_bwd_data.cpp +++ b/profiler/src/profile_convnd_bwd_data.cpp @@ -39,40 +39,40 @@ ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[], // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) ck::utils::conv::ConvParams params; - params.num_dim_spatial = num_dim_spatial; - params.N = std::stoi(argv[arg_idx++]); - params.K = std::stoi(argv[arg_idx++]); - params.C = std::stoi(argv[arg_idx++]); + params.num_dim_spatial_ = num_dim_spatial; + params.N_ = std::stoi(argv[arg_idx++]); + params.K_ = std::stoi(argv[arg_idx++]); + params.C_ = std::stoi(argv[arg_idx++]); - params.filter_spatial_lengths.resize(num_dim_spatial); + params.filter_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.filter_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.input_spatial_lengths.resize(num_dim_spatial); + params.input_spatial_lengths_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + params.input_spatial_lengths_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_strides.resize(num_dim_spatial); + params.conv_filter_strides_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_strides_[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_dilations.resize(num_dim_spatial); + params.conv_filter_dilations_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + params.conv_filter_dilations_[i] = std::stoi(argv[arg_idx++]); } - params.input_left_pads.resize(num_dim_spatial); + params.input_left_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + params.input_left_pads_[i] = std::stoi(argv[arg_idx++]); } - params.input_right_pads.resize(num_dim_spatial); + params.input_right_pads_.resize(num_dim_spatial); for(int i = 0; i < num_dim_spatial; ++i) { - params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + params.input_right_pads_[i] = std::stoi(argv[arg_idx++]); } return params; @@ -95,7 +95,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) printf("arg6: verification (0: no; 1: yes)\n"); printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg9: run kernel # of times (>1)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " "RightPx\n"); return 1; @@ -108,7 +108,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) const bool do_verification = std::stoi(argv[6]); const int init_method = std::stoi(argv[7]); const bool do_log = std::stoi(argv[8]); - const int nrepeat = std::stoi(argv[9]); + const bool time_kernel = std::stoi(argv[9]); ck::utils::conv::ConvParams params = parse_conv_params(num_dim_spatial, argv, preParams); @@ -132,17 +132,17 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) do_verification, init_method, do_log, - nrepeat, - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, params.GetOutputSpatialLengths(), - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads); + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_); break; case 2: @@ -157,17 +157,17 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) do_verification, init_method, do_log, - nrepeat, - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, params.GetOutputSpatialLengths(), - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads); + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_); break; case 3: @@ -182,17 +182,17 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) do_verification, init_method, do_log, - nrepeat, - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + time_kernel, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, params.GetOutputSpatialLengths(), - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads); + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_); break; default: break; diff --git a/profiler/src/profile_convnd_fwd.cpp b/profiler/src/profile_convnd_fwd.cpp index 1abd73c729..cb92587897 100644 --- a/profiler/src/profile_convnd_fwd.cpp +++ b/profiler/src/profile_convnd_fwd.cpp @@ -1,11 +1,12 @@ #include +#include #include #include #include #include #include -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "element_wise_operation.hpp" #include "fill.hpp" #include "profile_convnd_fwd.hpp" @@ -119,7 +120,7 @@ template , - ck::utils::FillUniform>>( - params, true, ck::utils::FillUniform{}, ck::utils::FillUniform{}); + ck::utils::FillUniformDistributionIntegerValue, + ck::utils::FillUniformDistributionIntegerValue>>( + params, + true, + ck::utils::FillUniformDistributionIntegerValue{}, + ck::utils::FillUniformDistributionIntegerValue{}); break; case 2: conv_instance = std::make_unique< @@ -165,12 +169,12 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, - ck::utils::FillUniform, - ck::utils::FillUniform>>( + ck::utils::FillUniformDistribution, + ck::utils::FillUniformDistribution>>( params, true, - ck::utils::FillUniform{}, - ck::utils::FillUniform{}); + ck::utils::FillUniformDistribution{}, + ck::utils::FillUniformDistribution{}); break; default: throw std::runtime_error("Unsupported init method!"); } @@ -181,11 +185,13 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params, _1, _2, _3); - OpInstanceRunEngine run_engine(*conv_instance, - reference_conv_fwd_fun); + + OpInstanceRunEngine run_engine( + *conv_instance, reference_conv_fwd_fun, do_verification); + auto best_conf = run_engine.Profile( conv::ConvolutionFwdInstances::template Get(), - nrepeat, + time_kernel, do_verification, do_log); @@ -201,7 +207,7 @@ void profile_convnd_instances(ConvDataType data_type, const ck::utils::conv::ConvParams& params, bool do_verification, bool do_log, - int nrepeat, + bool time_kernel, int init_method) { switch(data_layout) @@ -214,7 +220,7 @@ void profile_convnd_instances(ConvDataType data_type, params, do_verification, do_log, - nrepeat, + time_kernel, init_method, ConvolutionLayouts{}); break; @@ -223,7 +229,7 @@ void profile_convnd_instances(ConvDataType data_type, params, do_verification, do_log, - nrepeat, + time_kernel, init_method, ConvolutionLayouts{}); break; @@ -232,7 +238,7 @@ void profile_convnd_instances(ConvDataType data_type, params, do_verification, do_log, - nrepeat, + time_kernel, init_method, ConvolutionLayouts{}); break; @@ -241,7 +247,7 @@ void profile_convnd_instances(ConvDataType data_type, params, do_verification, do_log, - nrepeat, + time_kernel, init_method, ConvolutionLayouts{}); break; @@ -256,7 +262,7 @@ void profile_convnd_instances(ConvDataType data_type, params, do_verification, do_log, - nrepeat, + time_kernel, init_method, ConvolutionLayouts{}); break; @@ -265,7 +271,7 @@ void profile_convnd_instances(ConvDataType data_type, params, do_verification, do_log, - nrepeat, + time_kernel, init_method, ConvolutionLayouts{}); break; @@ -274,7 +280,7 @@ void profile_convnd_instances(ConvDataType data_type, params, do_verification, do_log, - nrepeat, + time_kernel, init_method, ConvolutionLayouts{}); break; @@ -283,7 +289,7 @@ void profile_convnd_instances(ConvDataType data_type, params, do_verification, do_log, - nrepeat, + time_kernel, init_method, ConvolutionLayouts{}); break; @@ -304,7 +310,7 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[]) bool do_verification{true}; int init_method{2}; bool do_log{false}; - int nrepeat{100}; + bool time_kernel{false}; int num_dim_spatial{2}; ConvParams params; @@ -318,7 +324,7 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[]) do_verification = std::stoi(argv[4]); init_method = std::stoi(argv[5]); do_log = std::stoi(argv[6]); - nrepeat = std::stoi(argv[7]); + time_kernel = std::stoi(argv[7]); num_dim_spatial = std::stoi(argv[8]); } if(argc >= 10) @@ -332,20 +338,20 @@ int ck::profiler::profile_convnd_fwd(int argc, char* argv[]) { case 1: profile_convnd_instances<1>( - data_type, data_layout, params, do_verification, do_log, nrepeat, init_method); + data_type, data_layout, params, do_verification, do_log, time_kernel, init_method); break; case 2: profile_convnd_instances<2>( - data_type, data_layout, params, do_verification, do_log, nrepeat, init_method); + data_type, data_layout, params, do_verification, do_log, time_kernel, init_method); break; case 3: profile_convnd_instances<3>( - data_type, data_layout, params, do_verification, do_log, nrepeat, init_method); + data_type, data_layout, params, do_verification, do_log, time_kernel, init_method); break; default: throw std::runtime_error("profile_conv_fwd: unsupported num_dim_spatial value: " + std::to_string(num_dim_spatial)); } - return 1; + return 0; } diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index 7a72be2d8e..0684e18322 100644 --- a/profiler/src/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -38,8 +38,8 @@ int profile_gemm(int argc, char* argv[]) printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg7: run kernel # of times (>1)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg14: split k into mulitiple batch\n"); exit(1); @@ -50,7 +50,7 @@ int profile_gemm(int argc, char* argv[]) const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); - const int nrepeat = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[7]); const int M = std::stoi(argv[8]); const int N = std::stoi(argv[9]); @@ -68,13 +68,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -88,13 +89,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -108,13 +110,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -128,13 +131,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -146,6 +150,7 @@ int profile_gemm(int argc, char* argv[]) else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) { ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -248,13 +257,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -268,13 +278,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -288,13 +299,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -308,13 +320,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -328,13 +341,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -348,13 +362,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -368,13 +383,14 @@ int profile_gemm(int argc, char* argv[]) ck::profiler::profile_gemm_impl( do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -388,5 +404,5 @@ int profile_gemm(int argc, char* argv[]) throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_gemm_add_add_fastgelu.cpp b/profiler/src/profile_gemm_add_add_fastgelu.cpp new file mode 100644 index 0000000000..602f14a78a --- /dev/null +++ b/profiler/src/profile_gemm_add_add_fastgelu.cpp @@ -0,0 +1,152 @@ +#include +#include +#include +#include +#include + +#include "profile_gemm_add_add_fastgelu_impl.hpp" + +int profile_gemm_add_add_fastgelu(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN_MN, // 0 + MK_NK_MN_MN_MN, // 1 + KM_KN_MN_MN_MN, // 2 + KM_NK_MN_MN_MN, // 3 + MK_KN_NM_MN_MN, // 4 + MK_NK_NM_MN_MN, // 5 + KM_KN_NM_MN_MN, // 6 + KM_NK_NM_MN_MN, // 7 + }; + + enum struct MatrixDataType + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F16_F16, // 1 + BF16_BF16_BF16_BF16_BF16, // 2 + INT8_INT8_INT8_INT8_INT8, // 3 + }; + + if(argc != 16) + { + // clang-format off + printf("arg1: tensor operation (gemm_add_add_fastgelu: GEMM+Add+Add+GeLU)\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);\n"); + printf(" 1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);\n"); + printf(" 2: E[m, n] = FastGeLU(A[k, m] * B[k, n] + D0[m, n] + D1[m, n]);\n"); + printf(" 3: E[m, n] = FastGeLU(A[k, m] * B[n, k] + D0[m, n] + D1[m, n]))\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideD1 = std::stoi(argv[14]); + const int StrideE = std::stoi(argv[15]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d0_type, + auto d1_type, + auto e_type, + auto a_layout, + auto b_layout, + auto d0_layout, + auto d1_layout, + auto e_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using D0DataType = decltype(d0_type); + using D1DataType = decltype(d1_type); + using EDataType = decltype(e_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using D0Layout = decltype(d0_layout); + using D1Layout = decltype(d1_layout); + using ELayout = decltype(e_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideD1 = ck::is_same_v ? N : M; + const int DefaultStrideE = ck::is_same_v ? N : M; + + return ck::profiler::profile_gemm_add_add_fastgelu_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideD1 < 0) ? DefaultStrideD1 : StrideD1, + (StrideE < 0) ? DefaultStrideE : StrideE); + }; + + if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::MK_NK_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::KM_KN_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && + layout == MatrixLayout::KM_NK_MN_MN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 0; + } +} diff --git a/profiler/src/profile_gemm_bias_2d.cpp b/profiler/src/profile_gemm_bias_2d.cpp index dd7e418087..51dba85f32 100644 --- a/profiler/src/profile_gemm_bias_2d.cpp +++ b/profiler/src/profile_gemm_bias_2d.cpp @@ -36,8 +36,8 @@ int profile_gemm_bias_2d(int argc, char* argv[]) printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg7: run kernel # of times (>1)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg14: alpha\n"); printf("arg15: beta\n"); @@ -50,7 +50,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); - const int nrepeat = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[7]); const int M = std::stoi(argv[8]); const int N = std::stoi(argv[9]); @@ -76,7 +76,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -99,7 +99,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -122,7 +122,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -145,7 +145,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -168,7 +168,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -191,7 +191,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -214,7 +214,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -237,7 +237,7 @@ int profile_gemm_bias_2d(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -252,5 +252,5 @@ int profile_gemm_bias_2d(int argc, char* argv[]) throw std::runtime_error("wrong! this data_type & layout is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_gemm_bias_add_reduce.cpp b/profiler/src/profile_gemm_bias_add_reduce.cpp new file mode 100644 index 0000000000..d36e5f1c83 --- /dev/null +++ b/profiler/src/profile_gemm_bias_add_reduce.cpp @@ -0,0 +1,159 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_bias_add_reduce_impl.hpp" + +int profile_gemm_bias_add_reduce(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct GemmReduceDataType + { + F32_F32_F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F16_F16_F32_F32, // 1 + }; + + if(!(argc == 14 || argc == 15)) + { + printf("arg1: tensor operation (gemm: GEMM+bias+add+Reduce)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int StrideC1 = std::stoi(argv[14]); + + if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmReduceDataType::F16_F16_F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_add_reduce_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 0; +} diff --git a/profiler/src/profile_gemm_bias_relu.cpp b/profiler/src/profile_gemm_bias_relu.cpp index 67a47cf9ec..bf035d9ad9 100644 --- a/profiler/src/profile_gemm_bias_relu.cpp +++ b/profiler/src/profile_gemm_bias_relu.cpp @@ -36,8 +36,8 @@ int profile_gemm_bias_relu(int argc, char* argv[]) printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg7: run kernel # of times (>1)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg14: split k into mulitiple batch\n"); exit(1); @@ -48,7 +48,7 @@ int profile_gemm_bias_relu(int argc, char* argv[]) const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); - const int nrepeat = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[7]); const int M = std::stoi(argv[8]); const int N = std::stoi(argv[9]); @@ -69,7 +69,7 @@ int profile_gemm_bias_relu(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -88,7 +88,7 @@ int profile_gemm_bias_relu(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -107,7 +107,7 @@ int profile_gemm_bias_relu(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -126,7 +126,7 @@ int profile_gemm_bias_relu(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -139,5 +139,5 @@ int profile_gemm_bias_relu(int argc, char* argv[]) throw std::runtime_error("wrong! this data_type & layout is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_gemm_bias_relu_add.cpp b/profiler/src/profile_gemm_bias_relu_add.cpp index 52406e93d6..9c324f6cf9 100644 --- a/profiler/src/profile_gemm_bias_relu_add.cpp +++ b/profiler/src/profile_gemm_bias_relu_add.cpp @@ -36,8 +36,8 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg7: run kernel # of times (>1)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n"); printf("arg15: split k into mulitiple batch\n"); exit(1); @@ -48,7 +48,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); - const int nrepeat = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[7]); const int M = std::stoi(argv[8]); const int N = std::stoi(argv[9]); @@ -70,7 +70,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -90,7 +90,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -110,7 +110,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -130,7 +130,7 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -144,5 +144,5 @@ int profile_gemm_bias_relu_add(int argc, char* argv[]) throw std::runtime_error("wrong! this data_type & layout is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_gemm_reduce.cpp b/profiler/src/profile_gemm_reduce.cpp index a83d4ce9a1..a23967acd7 100644 --- a/profiler/src/profile_gemm_reduce.cpp +++ b/profiler/src/profile_gemm_reduce.cpp @@ -32,8 +32,8 @@ int profile_gemm_reduce(int argc, char* argv[]) printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg7: run kernel # of times (>1)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); printf("arg14: split k into mulitiple batch\n"); exit(1); @@ -44,7 +44,7 @@ int profile_gemm_reduce(int argc, char* argv[]) const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); - const int nrepeat = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[7]); const int M = std::stoi(argv[8]); const int N = std::stoi(argv[9]); @@ -66,7 +66,7 @@ int profile_gemm_reduce(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -87,7 +87,7 @@ int profile_gemm_reduce(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -108,7 +108,7 @@ int profile_gemm_reduce(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -129,7 +129,7 @@ int profile_gemm_reduce(int argc, char* argv[]) do_verification, init_method, do_log, - nrepeat, + time_kernel, M, N, K, @@ -142,5 +142,5 @@ int profile_gemm_reduce(int argc, char* argv[]) throw std::runtime_error("wrong! this data_type & layout is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index 88a2a8f855..ea73d446e3 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -54,8 +54,8 @@ int profile_grouped_gemm(int argc, char* argv[]) printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); printf("arg4: verification (0: no; 1: yes)\n"); printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); - printf("arg8: print tensor value (0: no; 1: yes)\n"); - printf("arg7: run kernel # of times (>1)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " "64,64 64,64 128,128)\n"); exit(1); @@ -66,7 +66,7 @@ int profile_grouped_gemm(int argc, char* argv[]) const bool do_verification = std::stoi(argv[4]); const int init_method = std::stoi(argv[5]); const bool do_log = std::stoi(argv[6]); - const int nrepeat = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[7]); const auto Ms = argToIntArray(argv[8]); const auto Ns = argToIntArray(argv[9]); @@ -79,6 +79,7 @@ int profile_grouped_gemm(int argc, char* argv[]) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { ck::profiler::profile_grouped_gemm_impl(do_verification, init_method, do_log, - nrepeat, + time_kernel, Ms, Ns, Ks, @@ -97,6 +98,7 @@ int profile_grouped_gemm(int argc, char* argv[]) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) { ck::profiler::profile_grouped_gemm_impl(do_verification, init_method, do_log, - nrepeat, + time_kernel, Ms, Ns, Ks, @@ -115,6 +117,7 @@ int profile_grouped_gemm(int argc, char* argv[]) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) { ck::profiler::profile_grouped_gemm_impl(do_verification, init_method, do_log, - nrepeat, + time_kernel, Ms, Ns, Ks, @@ -133,6 +136,7 @@ int profile_grouped_gemm(int argc, char* argv[]) else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) { ck::profiler::profile_grouped_gemm_impl(do_verification, init_method, do_log, - nrepeat, + time_kernel, Ms, Ns, Ks, @@ -153,5 +157,5 @@ int profile_grouped_gemm(int argc, char* argv[]) throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); } - return 1; + return 0; } diff --git a/profiler/src/profile_reduce.cpp b/profiler/src/profile_reduce.cpp index c6dea1e385..bdbac4fab4 100644 --- a/profiler/src/profile_reduce.cpp +++ b/profiler/src/profile_reduce.cpp @@ -1,27 +1,19 @@ #include #include -#include -#include #include #include #include #include #include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" +#include "data_type_enum.hpp" #include "reduction_enums.hpp" +#include "host_common_util.hpp" #include "profile_reduce_impl.hpp" using namespace std; -using ck::NanPropagation; -using ck::ReduceTensorIndices; using ck::ReduceTensorOp; static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, @@ -38,63 +30,9 @@ static struct option long_options[] = {{"inLengths", required_argument, nullptr, {"bf16", no_argument, nullptr, '?'}, {"dumpout", required_argument, nullptr, 'o'}, {"verify", required_argument, nullptr, 'v'}, - {"log", required_argument, nullptr, 'l'}, {"help", no_argument, nullptr, '?'}, {nullptr, 0, nullptr, 0}}; -template -static T getSingleValueFromString(const string& valueStr) -{ - std::istringstream iss(valueStr); - - T val; - - iss >> val; - - return (val); -}; - -template -static std::vector getTypeValuesFromString(const char* cstr_values) -{ - std::string valuesStr(cstr_values); - - std::vector values; - std::size_t pos = 0; - std::size_t new_pos; - - new_pos = valuesStr.find(',', pos); - while(new_pos != std::string::npos) - { - const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); - - T val = getSingleValueFromString(sliceStr); - - values.push_back(val); - - pos = new_pos + 1; - new_pos = valuesStr.find(',', pos); - }; - - std::string sliceStr = valuesStr.substr(pos); - T val = getSingleValueFromString(sliceStr); - - values.push_back(val); - - return (values); -} - -enum struct AppDataType -{ - appHalf = 0, - appFloat = 1, - appInt32 = 2, - appInt8 = 3, - appInt8x4 = 4, - appBFloat16 = 5, - appDouble = 6, -}; - static void check_reduce_dims(const int rank, const std::vector& reduceDims) { for(auto dim : reduceDims) @@ -113,7 +51,7 @@ static void check_reduce_dims(const int rank, const std::vector& reduceDims }; }; -class AppArgs +class ReduceProfilerArgs { private: int option_index = 0; @@ -130,26 +68,23 @@ class AppArgs std::vector scales; - ReduceTensorOp reduceOp = ReduceTensorOp::ADD; - AppDataType compTypeId = AppDataType::appFloat; - AppDataType outTypeId = AppDataType::appFloat; + ReduceTensorOp reduceOp = ReduceTensorOp::ADD; + ck::DataTypeEnum compTypeId = ck::DataTypeEnum::Float; + ck::DataTypeEnum outTypeId = ck::DataTypeEnum::Float; bool compType_assigned = false; bool outType_assigned = false; - NanPropagation nanOpt = NanPropagation::NOT_PROPAGATE_NAN; - ReduceTensorIndices indicesOpt = ReduceTensorIndices::NO_INDICES; - bool do_log = false; - bool do_verification = false; - bool do_dumpout = false; + int nanOpt = 0; + int indicesOpt = 0; + bool do_verification = false; + bool do_dumpout = false; int init_method; - int nrepeat; + bool time_kernel; - bool need_indices = false; - - AppArgs() = default; - ~AppArgs() = default; + ReduceProfilerArgs() = default; + ~ReduceProfilerArgs() = default; void show_usage(const char* cmd) { @@ -166,8 +101,11 @@ class AppArgs std::cout << "--outType or -W, optional enum value indicating the type of the reduced " "output, which could be float when the input data is half" << std::endl; - std::cout << "--nanOpt or -N, enum value indicates the selection for NanOpt" << std::endl; - std::cout << "--indicesOpt or -I, enum value indicates the selection for IndicesOpt" + std::cout + << "--nanOpt or -N, 1/0 value indicates the selection to use or not use Nan-Propagation" + << std::endl; + std::cout << "--indicesOpt or -I, 1/0 value indicates the selection to use or not use " + "index in reduction" << std::endl; std::cout << "--scales or -S, comma separated two float values for alpha and beta" << std::endl; @@ -181,18 +119,19 @@ class AppArgs std::cout << "--dumpout or -o, 1/0 to indicate where to save the reduction result to files " "for further analysis" << std::endl; - std::cout << "--log or -l, 1/0 to indicate whether to log some information" << std::endl; }; int processArgs(int argc, char* argv[]) { - unsigned int ch; + using ck::host_common::getTypeValuesFromString; + + int ch; optind++; // to skip the "reduce" module name while(1) { - ch = getopt_long(argc, argv, "D:R:O:C:W:N:I:S:v:o:l:", long_options, &option_index); + ch = getopt_long(argc, argv, "D:R:O:C:W:N:I:S:v:o:", long_options, &option_index); if(ch == -1) break; switch(ch) @@ -219,27 +158,27 @@ class AppArgs if(!optarg) throw std::runtime_error("Invalid option format!"); - compTypeId = static_cast(std::atoi(optarg)); + compTypeId = static_cast(std::atoi(optarg)); compType_assigned = true; break; case 'W': if(!optarg) throw std::runtime_error("Invalid option format!"); - outTypeId = static_cast(std::atoi(optarg)); + outTypeId = static_cast(std::atoi(optarg)); outType_assigned = true; break; case 'N': if(!optarg) throw std::runtime_error("Invalid option format!"); - nanOpt = static_cast(std::atoi(optarg)); + nanOpt = std::atoi(optarg); break; case 'I': if(!optarg) throw std::runtime_error("Invalid option format!"); - indicesOpt = static_cast(std::atoi(optarg)); + indicesOpt = std::atoi(optarg); break; case 'S': if(!optarg) @@ -262,12 +201,6 @@ class AppArgs do_dumpout = static_cast(std::atoi(optarg)); break; - case 'l': - if(!optarg) - throw std::runtime_error("Invalid option format!"); - - do_log = static_cast(std::atoi(optarg)); - break; case '?': if(std::string(long_options[option_index].name) == "half") use_half = true; @@ -295,7 +228,7 @@ class AppArgs throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!"); init_method = std::atoi(argv[optind++]); - nrepeat = std::atoi(argv[optind]); + time_kernel = static_cast(std::atoi(argv[optind])); if(scales.empty()) { @@ -306,9 +239,6 @@ class AppArgs if(reduceOp == ReduceTensorOp::MIN || reduceOp == ReduceTensorOp::MAX || reduceOp == ReduceTensorOp::AMAX) { - if(indicesOpt != ReduceTensorIndices::NO_INDICES) - need_indices = true; - // for indexable operations, no need to assign compType and outType, just let them be // same as inType compType_assigned = false; @@ -322,9 +252,10 @@ class AppArgs int profile_reduce(int argc, char* argv[]) { - using namespace ck::profiler; + using ck::DataTypeEnum; + using ck::profiler::profile_reduce_impl; - AppArgs args; + ReduceProfilerArgs args; if(args.processArgs(argc, argv) < 0) return (-1); @@ -339,42 +270,41 @@ int profile_reduce(int argc, char* argv[]) if(args.use_half) { if(!args.compType_assigned) - args.compTypeId = AppDataType::appHalf; + args.compTypeId = DataTypeEnum::Half; if(args.outType_assigned && - (args.outTypeId != AppDataType::appHalf && args.outTypeId != AppDataType::appFloat)) - args.outTypeId = AppDataType::appFloat; + (args.outTypeId != DataTypeEnum::Half && args.outTypeId != DataTypeEnum::Float)) + args.outTypeId = DataTypeEnum::Float; if(!args.outType_assigned) - args.outTypeId = AppDataType::appHalf; + args.outTypeId = DataTypeEnum::Half; - if(args.compTypeId == AppDataType::appHalf) + if(args.compTypeId == DataTypeEnum::Half) { - profile_reduce_impl(args.do_verification, - args.init_method, - args.do_log, - args.do_dumpout, - args.nrepeat, - args.inLengths, - args.reduceDims, - args.reduceOp, - args.nanOpt, - args.indicesOpt, - args.scales[0], - args.scales[1]); + profile_reduce_impl( + args.do_verification, + args.init_method, + args.do_dumpout, + args.time_kernel, + args.inLengths, + args.reduceDims, + args.reduceOp, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), + args.scales[0], + args.scales[1]); } - else if(args.compTypeId == AppDataType::appFloat) + else if(args.compTypeId == DataTypeEnum::Float) { profile_reduce_impl(args.do_verification, args.init_method, - args.do_log, args.do_dumpout, - args.nrepeat, + args.time_kernel, args.inLengths, args.reduceDims, args.reduceOp, - args.nanOpt, - args.indicesOpt, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), args.scales[0], args.scales[1]); } @@ -385,56 +315,53 @@ int profile_reduce(int argc, char* argv[]) { profile_reduce_impl(args.do_verification, args.init_method, - args.do_log, args.do_dumpout, - args.nrepeat, + args.time_kernel, args.inLengths, args.reduceDims, args.reduceOp, - args.nanOpt, - args.indicesOpt, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), args.scales[0], args.scales[1]); } else if(args.use_int8) { if(!args.compType_assigned) - args.compTypeId = AppDataType::appInt8; + args.compTypeId = DataTypeEnum::Int8; if(args.outType_assigned && - (args.outTypeId != AppDataType::appInt8 && args.outTypeId != AppDataType::appInt32)) - args.outTypeId = AppDataType::appInt32; + (args.outTypeId != DataTypeEnum::Int8 && args.outTypeId != DataTypeEnum::Int32)) + args.outTypeId = DataTypeEnum::Int32; if(!args.outType_assigned) - args.outTypeId = AppDataType::appInt8; + args.outTypeId = DataTypeEnum::Int8; - if(args.compTypeId == AppDataType::appInt8) + if(args.compTypeId == DataTypeEnum::Int8) { profile_reduce_impl(args.do_verification, args.init_method, - args.do_log, args.do_dumpout, - args.nrepeat, + args.time_kernel, args.inLengths, args.reduceDims, args.reduceOp, - args.nanOpt, - args.indicesOpt, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), args.scales[0], args.scales[1]); } - else if(args.compTypeId == AppDataType::appInt32) + else if(args.compTypeId == DataTypeEnum::Int32) { profile_reduce_impl(args.do_verification, args.init_method, - args.do_log, args.do_dumpout, - args.nrepeat, + args.time_kernel, args.inLengths, args.reduceDims, args.reduceOp, - args.nanOpt, - args.indicesOpt, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), args.scales[0], args.scales[1]); } @@ -444,54 +371,51 @@ int profile_reduce(int argc, char* argv[]) else if(args.use_bf16) { if(args.outType_assigned && - (args.outTypeId != AppDataType::appBFloat16 && args.outTypeId != AppDataType::appFloat)) - args.outTypeId = AppDataType::appFloat; + (args.outTypeId != DataTypeEnum::BFloat16 && args.outTypeId != DataTypeEnum::Float)) + args.outTypeId = DataTypeEnum::Float; if(!args.outType_assigned) - args.outTypeId = AppDataType::appBFloat16; + args.outTypeId = DataTypeEnum::BFloat16; profile_reduce_impl(args.do_verification, args.init_method, - args.do_log, args.do_dumpout, - args.nrepeat, + args.time_kernel, args.inLengths, args.reduceDims, args.reduceOp, - args.nanOpt, - args.indicesOpt, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), args.scales[0], args.scales[1]); } else { - if(args.compTypeId == AppDataType::appFloat) + if(args.compTypeId == DataTypeEnum::Float) { profile_reduce_impl(args.do_verification, args.init_method, - args.do_log, args.do_dumpout, - args.nrepeat, + args.time_kernel, args.inLengths, args.reduceDims, args.reduceOp, - args.nanOpt, - args.indicesOpt, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), args.scales[0], args.scales[1]); } - else if(args.compTypeId == AppDataType::appDouble) + else if(args.compTypeId == DataTypeEnum::Double) { profile_reduce_impl(args.do_verification, args.init_method, - args.do_log, args.do_dumpout, - args.nrepeat, + args.time_kernel, args.inLengths, args.reduceDims, args.reduceOp, - args.nanOpt, - args.indicesOpt, + static_cast(args.nanOpt), + static_cast(args.indicesOpt), args.scales[0], args.scales[1]); } diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index 2a8078ca5f..ceaebf2c7c 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -11,8 +11,10 @@ int profile_gemm_bias_2d(int, char*[]); int profile_gemm_bias_relu(int, char*[]); int profile_gemm_bias_relu_add(int, char*[]); int profile_gemm_reduce(int, char*[]); +int profile_gemm_bias_add_reduce(int, char*[]); int profile_batched_gemm(int, char*[]); int profile_grouped_gemm(int, char*[]); +int profile_conv_fwd(int, char*[]); int profile_conv_fwd_bias_relu(int, char*[]); int profile_conv_fwd_bias_relu_add(int, char*[]); int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); @@ -20,9 +22,39 @@ int profile_convnd_bwd_data(int, char*[], int); int profile_reduce(int, char*[]); int profile_conv_bwd_weight(int, char*[]); int profile_batched_gemm_reduce(int, char*[]); +int profile_gemm_add_add_fastgelu(int, char*[]); + +static void print_helper_message() +{ + // clang-format off + printf("arg1: tensor operation (gemm: GEMM\n" + " gemm_bias_2d: GEMM+Bias(2D)\n" + " gemm_bias_relu: GEMM+Bias+ReLU\n" + " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n" + " gemm_reduce: GEMM+Reduce\n" + " grouped_gemm: Grouped GEMM\n" + " conv_fwd: ForwardConvolution\n" + " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" + " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" + " conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n" + " conv1d_bwd_data: BackwardConvolution data 1 dim\n" + " conv2d_bwd_data: BackwardConvolution data 2 dim\n" + " conv3d_bwd_data: BackwardConvolution data 3 dim\n" + " reduce: Reduce\n" + " conv2d_bwd_weight: Backward Weight Convolution 2d\n" + " gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU\n"); + // clang-format on +} int main(int argc, char* argv[]) { + if(argc == 1) + { + print_helper_message(); + + return 0; + } + if(strcmp(argv[1], "gemm") == 0) { return profile_gemm(argc, argv); @@ -43,6 +75,10 @@ int main(int argc, char* argv[]) { return profile_gemm_reduce(argc, argv); } + else if(strcmp(argv[1], "gemm_bias_add_reduce") == 0) + { + return profile_gemm_bias_add_reduce(argc, argv); + } else if(strcmp(argv[1], "batched_gemm") == 0) { return profile_batched_gemm(argc, argv); @@ -53,7 +89,7 @@ int main(int argc, char* argv[]) } else if(strcmp(argv[1], "grouped_gemm") == 0) { - profile_grouped_gemm(argc, argv); + return profile_grouped_gemm(argc, argv); } else if(strcmp(argv[1], "conv_fwd") == 0) { @@ -91,25 +127,14 @@ int main(int argc, char* argv[]) { return profile_conv_bwd_weight(argc, argv); } + else if(strcmp(argv[1], "gemm_add_add_fastgelu") == 0) + { + return profile_gemm_add_add_fastgelu(argc, argv); + } else { - // clang-format off - printf("arg1: tensor operation (gemm: GEMM\n" - " gemm_bias_2d: GEMM+Bias(2D)\n" - " gemm_bias_relu: GEMM+Bias+ReLU\n" - " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n" - " gemm_reduce: GEMM+Reduce\n" - " grouped_gemm: Grouped GEMM\n" - " conv_fwd: ForwardConvolution\n" - " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" - " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" - " conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n" - " conv1d_bwd_data: BackwardConvolution data 1 dim\n" - " conv2d_bwd_data: BackwardConvolution data 2 dim\n" - " conv3d_bwd_data: BackwardConvolution data 3 dim\n" - " reduce: REDUCE\n" - " conv2d_bwd_weight: Backward Weight Convolution 2d\n"); - // clang-format on + print_helper_message(); + + return 0; } - return 0; } diff --git a/script/parse_perf_data.py b/script/parse_perf_data.py new file mode 100644 index 0000000000..4cb13e6243 --- /dev/null +++ b/script/parse_perf_data.py @@ -0,0 +1,290 @@ +#!/usr/bin/env python3 +import os, io, argparse, datetime, re +import numpy as np +import sqlalchemy +from sqlalchemy.types import NVARCHAR, Float, Integer +import pymysql +import pandas as pd +from sshtunnel import SSHTunnelForwarder + +def print_to_string(*args, **kwargs): + output = io.StringIO() + print(*args, file=output, **kwargs) + contents = output.getvalue() + output.close() + return contents + +def parse_args(): + parser = argparse.ArgumentParser(description='Parse results from tf benchmark runs') + parser.add_argument('filename', type=str, help='Log file to prase or directory containing log files') + args = parser.parse_args() + files = [] + if os.path.isdir(args.filename): + all_files = os.listdir(args.filename) + for name in all_files: + if not 'log' in name: + continue + files.append(os.path.join(args.filename, name)) + else: + files = [args.filename] + args.files = files + return args + +def main(): + args = parse_args() + tests = [] + kernels=[] + tflops=[] + dtype=[] + alayout=[] + blayout=[] + M=[] + N=[] + K=[] + StrideA=[] + StrideB=[] + StrideC=[] + #parse results, get the Tflops value for "Best Perf" kernels + + glue="" + for filename in args.files: + for line in open(filename): + if 'Branch name' in line: + lst=line.split() + branch_name=lst[2] + if 'On branch' in line: + lst=line.split() + branch_name=lst[2] + if 'Node name' in line: + lst=line.split() + node_id=lst[2] + if 'GPU_arch' in line: + lst=line.split() + gpu_arch=lst[2] + if 'HIP version' in line: + lst=line.split() + hip_vers=lst[2] + if 'Compute Unit' in line: + lst=line.split() + compute_units=lst[2] + if 'InstalledDir' in line: + lst=line.split() + rocm_vers=lst[1][lst[1].find('/opt/rocm-')+len('/opt/rocm-'):lst[1].rfind('/llvm/bin')] + print("Branch name:",branch_name) + print("Node name:",node_id) + print("GPU_arch:",gpu_arch) + print("Compute units:",compute_units) + print("ROCM_version:",rocm_vers) + print("HIP_version:",hip_vers) + + + #parse gemm performance tests: + if 'gemm' in filename: + for filename in args.files: + for line in open(filename): + if 'Best Perf' in line: + lst=line.split() + if len(lst)>=37: #the line is complete + tests.append(glue.join(lst[5:30])) + kernels.append(glue.join(lst[37:])) + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + elif len(lst)<37 and len(lst)>=33: #the tflops are available + tests.append(glue.join(lst[5:30])) + kernels.append("N/A") + tflops.append(lst[33]) + dtype.append(lst[5]) + alayout.append(lst[8]) + blayout.append(lst[11]) + M.append(lst[14]) + N.append(lst[17]) + K.append(lst[20]) + StrideA.append(lst[23]) + StrideB.append(lst[26]) + StrideC.append(lst[29]) + print("warning: incomplete line:",lst) + elif len(lst)<33: #even the tflops are not available + print("Error in ckProfiler output!") + print("warning: incomplete line=",lst) + #sort results + #sorted_tests = sorted(tests) + #print("sorted tests:",sorted_tests) + sorted_tflops = [x for _,x in sorted(zip(tests,tflops))] + #sorted_kernels = [x for _,x in sorted(zip(tests,kernels))] + test_list=list(range(1,len(tests)+1)) + + #parse resnet50 performance tests: + if 'resnet50' in filename: + for filename in args.files: + for line in open(filename): + if 'Best Perf' in line: + lst=line.split() + tflops.append(lst[4]) + + print("Number of tests:",len(tflops)) + sql_hostname = '127.0.0.1' + sql_username = os.environ["dbuser"] + sql_password = os.environ["dbpassword"] + sql_main_database = 'miopen_perf' + sql_port = 3306 + ssh_host = os.environ["dbsship"] + ssh_user = os.environ["dbsshuser"] + ssh_port = int(os.environ["dbsshport"]) + ssh_pass = os.environ["dbsshpassword"] + + with SSHTunnelForwarder( + (ssh_host, ssh_port), + ssh_username=ssh_user, + ssh_password=ssh_pass, + remote_bind_address=(sql_hostname, sql_port)) as tunnel: + + sqlEngine = sqlalchemy.create_engine('mysql+pymysql://{0}:{1}@{2}:{3}/{4}'. + format(sql_username, sql_password, sql_hostname, tunnel.local_bind_port, sql_main_database)) + conn = sqlEngine.connect() + + #save gemm performance tests: + if 'gemm' in filename: + + #write the ck_gemm_test_params table + #only needed once the test set changes + ''' + sorted_dtypes = [x for _,x in sorted(zip(tests,dtype))] + sorted_alayout = [x for _,x in sorted(zip(tests,alayout))] + sorted_blayout = [x for _,x in sorted(zip(tests,blayout))] + sorted_M = [x for _,x in sorted(zip(tests,M))] + sorted_N = [x for _,x in sorted(zip(tests,N))] + sorted_K = [x for _,x in sorted(zip(tests,K))] + sorted_StrideA = [x for _,x in sorted(zip(tests,StrideA))] + sorted_StrideB = [x for _,x in sorted(zip(tests,StrideB))] + sorted_StrideC = [x for _,x in sorted(zip(tests,StrideC))] + ck_gemm_params=[test_list,sorted_dtypes,sorted_alayout,sorted_blayout, + sorted_M,sorted_N,sorted_K,sorted_StrideA,sorted_StrideB, + sorted_StrideC] + df=pd.DataFrame(np.transpose(ck_gemm_params),columns=['Test_number','Data_type', + 'Alayout','BLayout','M','N','K', 'StrideA','StrideB','StrideC']) + print(df) + + dtypes = { + 'Test_number': Integer(), + 'Data_type': NVARCHAR(length=5), + 'Alayout': NVARCHAR(length=12), + 'Blayout': NVARCHAR(length=12), + 'M': Integer(), + 'N': Integer(), + 'K': Integer(), + 'StrideA': Integer(), + 'StrideB': Integer(), + 'StrideC': Integer() + } + df.to_sql("ck_gemm_test_params",conn,if_exists='replace',index=False, dtype=dtypes) + ''' + + #read baseline results for the latest develop branch + query = '''SELECT * from ck_gemm_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_gemm_tflops where Branch_ID='develop' );''' + tflops_base = pd.read_sql_query(query, conn) + + #write new results to the db + testlist=[] + for i in range(1,len(tests)+1): + testlist.append("Test%i"%i) + ck_gemm_tflops=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(datetime.datetime.now())] + flops=pd.DataFrame(data=[ck_gemm_tflops],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Datetime']) + df_add=pd.DataFrame(data=[sorted_tflops],columns=testlist) + flops=pd.concat([flops,df_add],axis=1) + print("new tflops for gemm tests:",flops) + flops.to_sql("ck_gemm_tflops",conn,if_exists='append',index=False) + + #save resnet50 performance tests: + if 'resnet50' in filename: + #read baseline results for the latest develop branch + query = '''SELECT * from ck_resnet50_N256_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_resnet50_N256_tflops where Branch_ID='develop' );''' + tflops_base_N256 = pd.read_sql_query(query, conn) + query = '''SELECT * from ck_resnet50_N4_tflops WHERE Datetime = (SELECT MAX(Datetime) FROM ck_resnet50_N4_tflops where Branch_ID='develop' );''' + tflops_base_N4 = pd.read_sql_query(query, conn) + + #write new results to the db + testlist=[] + for i in range(1,50): + testlist.append("Layer%i"%i) + ck_resnet_tflops=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(datetime.datetime.now())] + flops0=pd.DataFrame(data=[ck_resnet_tflops],columns=['Branch_ID','Node_ID','GPU_arch','Compute Units','ROCM_version','HIP_version','Datetime']) + df_add=pd.DataFrame(data=[tflops[0:49]],columns=testlist) + flops=pd.concat([flops0,df_add],axis=1) + print("new tflops for N=256 resnet50 test:",flops) + flops.to_sql("ck_resnet50_N256_tflops",conn,if_exists='append',index=False) + df_add=pd.DataFrame(data=[tflops[49:98]],columns=testlist) + flops=pd.concat([flops0,df_add],axis=1) + print("new tflops for N=4 resnet50 test:",flops) + flops.to_sql("ck_resnet50_N4_tflops",conn,if_exists='append',index=False) + + conn.close() + + #compare the results to the baseline if baseline exists + regression=0 + if 'gemm' in filename: + if not tflops_base.empty: + base=tflops_base[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(sorted_tflops[i]): + print("test # ",i,"shows regression by {:.3f}%".format( + (float(sorted_tflops[i])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(sorted_tflops[i])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline") + if 'resnet50' in filename: + if not tflops_base_N256.empty: + base=tflops_base_N256[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(tflops[i]): + print("layer # ",i,"shows regression by {:.3f}%".format( + (float(tflops[i])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(tflops[i])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline for N=256") + if not tflops_base_N4.empty: + base=tflops_base_N4[testlist].to_numpy(dtype='float') + base_list=base[0] + ave_perf=0 + for i in range(len(base_list)): + # success criterion: + if base_list[i]>1.01*float(tflops[i+49]): + print("layer # ",i,"shows regression by {:.3f}%".format( + (float(tflops[i+49])-base_list[i])/base_list[i]*100)) + regression=1 + ave_perf=ave_perf+float(tflops[i+49])/base_list[i] + if regression==0: + print("no regressions found") + ave_perf=ave_perf/len(base_list) + print("average performance relative to baseline:",ave_perf) + else: + print("could not find a baseline for N=4") + + #return 0 if performance criteria met, otherwise return 1 + return regression + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/script/profile_conv.sh b/script/profile_conv.sh index f3a6d2c70c..c3ba39c926 100755 --- a/script/profile_conv.sh +++ b/script/profile_conv.sh @@ -3,9 +3,9 @@ ## GPU visibility export HIP_VISIBLE_DEVICES=0 - make -j ckProfiler +# make -j ckProfiler - DRIVER="./profiler/ckProfiler" + DRIVER="../build/bin/ckProfiler" OP=$1 DATATYPE=$2 @@ -26,7 +26,7 @@ REPEAT=$9 N=${10} -# Resnet50 from Bing +# Resnet50 (no duplicated layer) ######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 @@ -47,60 +47,60 @@ REPEAT=$9 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 -#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 8 7 7 224 224 2 2 1 1 3 3 3 3 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 -# Resnet50 from Bing -#################### op____________________ datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 -#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 -#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +# Resnet50 fusion +####### op_________________ datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C_ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +$DRIVER conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +$DRIVER conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 # Resnet50 diff --git a/script/profile_gemm.sh b/script/profile_gemm.sh index 036d0440e0..b816c5101f 100755 --- a/script/profile_gemm.sh +++ b/script/profile_gemm.sh @@ -1,12 +1,10 @@ #!/bin/bash ## GPU visibility - export HIP_VISIBLE_DEVICES=0 - - make -j ckProfiler - - DRIVER="./profiler/ckProfiler" - +export HIP_VISIBLE_DEVICES=0 +#make -j ckProfiler +DRIVER="../build/bin/ckProfiler" +echo $DRIVER OP=$1 DATATYPE=$2 LAYOUT=$3 @@ -43,3 +41,13 @@ $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1024 1024 1024 1088 1 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2048 2048 2048 2112 2112 2112 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 4096 4096 4096 4160 4160 4160 $DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 8192 8192 8192 8256 8256 8256 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 6656 8192 8192 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 3328 4096 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1664 2048 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 832 1024 1024 -1 -1 -1 + +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 7040 8192 8192 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 5120 5632 4096 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 2560 2816 2048 -1 -1 -1 +$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $REPEAT 1280 1408 1024 -1 -1 -1 diff --git a/script/run_performance_tests.sh b/script/run_performance_tests.sh new file mode 100755 index 0000000000..95d63d0ffe --- /dev/null +++ b/script/run_performance_tests.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# +# in order to run this script you'd first need to build the ckProfiler executable in ../build/bin/ +# and make sure the following python packages are installed in your environment: + +pip3 install --upgrade pip +pip3 install sqlalchemy pymysql pandas sshtunnel + +# you would also need to set up some environment variables in order to +# post your new test results to the database and compare them to the baseline +# please contact Illia.Silin@amd.com for more details +# + +export gemm_log="perf_gemm.log" +rm -f $gemm_log +git status | grep -e 'On branch' > ${gemm_log} +echo -n 'Node name: ' >>${gemm_log}; hostname >> ${gemm_log} +#get GPU_arch and number of compute units from rocminfo +echo -n "GPU_arch: " >> ${gemm_log}; rocminfo | grep "Name:" | grep "gfx" >> ${gemm_log} +rocminfo | grep "Compute Unit:" >> ${gemm_log} +hipcc --version | grep -e 'HIP version' >> ${gemm_log} +/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${gemm_log} +./profile_gemm.sh gemm 0 0 0 1 0 5 | tee -a ${gemm_log} +./profile_gemm.sh gemm 1 0 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 0 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 0 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 0 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 1 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 1 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 0 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 1 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 2 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 0 3 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 1 3 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 2 3 0 1 0 5 | tee -a $gemm_log +./profile_gemm.sh gemm 3 3 0 1 0 5 | tee -a $gemm_log + +python3 parse_perf_data.py ${gemm_log} + +#run resnet50 test +export resnet_log="perf_resnet50.log" +rm -f $resnet_log +git status | grep -e 'On branch' > ${resnet_log} +echo -n 'Node name: '>>${resnet_log}; hostname >>${resnet_log} +#get GPU_arch and number of compute units from rocminfo +echo -n "GPU_arch: " >> ${resnet_log}; rocminfo | grep "Name:" | grep "gfx" >> ${resnet_log} +rocminfo | grep "Compute Unit:" >> ${resnet_log} +hipcc --version | grep -e 'HIP version' >> ${resnet_log} +/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> ${resnet_log} +#first run tests with N=256 +./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 256 | tee -a ${resnet_log} +#then run with N=4 +./profile_conv.sh conv_fwd_bias_relu 1 1 1 1 0 2 0 1 4 | tee -a ${resnet_log} +#the script will put the results from N=256 and N=4 runs into separate tables +python3 parse_perf_data.py ${resnet_log} diff --git a/script/test_reduce_no_index.sh b/script/test_reduce_no_index.sh index 95e563c93c..b956303837 100755 --- a/script/test_reduce_no_index.sh +++ b/script/test_reduce_no_index.sh @@ -15,6 +15,17 @@ bin/test_reduce_no_index -D 64,4,280,82 -R 1 0 2 bin/test_reduce_no_index -D 64,4,280,82 -R 2 0 2 bin/test_reduce_no_index -D 64,4,280,82 -R 3 0 2 +## for float64 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1,2,3 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 0 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 1 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 2 6 2 +bin/test_reduce_no_index -D 64,4,280,82 -R 3 6 2 + ## for float16 bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2,3 1 2 bin/test_reduce_no_index -D 64,4,280,82 -R 0,1,2 1 2 diff --git a/script/test_reduce_with_index.sh b/script/test_reduce_with_index.sh index 8e7ed33847..b0843ba6c1 100755 --- a/script/test_reduce_with_index.sh +++ b/script/test_reduce_with_index.sh @@ -15,6 +15,17 @@ bin/test_reduce_with_index -D 64,4,280,82 -R 1 0 2 bin/test_reduce_with_index -D 64,4,280,82 -R 2 0 2 bin/test_reduce_with_index -D 64,4,280,82 -R 3 0 2 +## for float64 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0,2,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1,2,3 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 0 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 1 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 2 6 2 +bin/test_reduce_with_index -D 64,4,280,82 -R 3 6 2 + ## for float16 bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2,3 1 2 bin/test_reduce_with_index -D 64,4,280,82 -R 0,1,2 1 2 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index cc0778de4c..47ca0b663d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,6 +1,8 @@ include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/ ${PROJECT_SOURCE_DIR}/include/ck ${PROJECT_SOURCE_DIR}/include/ck/utility + ${PROJECT_SOURCE_DIR}/include/ck/host_utility ${PROJECT_SOURCE_DIR}/include/ck/tensor_description ${PROJECT_SOURCE_DIR}/include/ck/tensor ${PROJECT_SOURCE_DIR}/include/ck/problem_transform @@ -21,7 +23,8 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/external/include/half ) -add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR}) +include(googletest) + add_custom_target(tests) @@ -41,7 +44,7 @@ function(add_gtest_executable TEST_NAME) add_dependencies(tests ${TEST_NAME}) add_dependencies(check ${TEST_NAME}) # suppress gtest warnings - target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors) + target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) target_link_libraries(${TEST_NAME} PRIVATE gtest_main) gtest_discover_tests(${TEST_NAME}) endfunction(add_gtest_executable TEST_NAME) @@ -60,3 +63,7 @@ add_subdirectory(grouped_gemm) add_subdirectory(convnd_fwd) add_subdirectory(reduce) add_subdirectory(conv2d_bwd_weight) +add_subdirectory(convnd_bwd_data) +add_subdirectory(block_to_ctile_map) +add_subdirectory(softmax) +# DONOT add client_app, that is tested via CI independently diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp index ce061c644b..7b311cff17 100644 --- a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp +++ b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp @@ -22,7 +22,7 @@ int main() Row, Row, Row>( - true, 1, false, 1, M, N, K, K, N, N, BatchCount); + true, 1, false, false, M, N, K, K, N, N, BatchCount); pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( - true, 1, false, 1, M, N, K, K, K, N, BatchCount); + true, 1, false, false, M, N, K, K, K, N, BatchCount); pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( - true, 1, false, 1, M, N, K, M, N, N, BatchCount); + true, 1, false, false, M, N, K, M, N, N, BatchCount); pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( - true, 1, false, 1, M, N, K, M, K, N, BatchCount); + true, 1, false, false, M, N, K, M, K, N, BatchCount); if(pass) { diff --git a/test/block_to_ctile_map/CMakeLists.txt b/test/block_to_ctile_map/CMakeLists.txt new file mode 100644 index 0000000000..97dfbb2b55 --- /dev/null +++ b/test/block_to_ctile_map/CMakeLists.txt @@ -0,0 +1 @@ +add_gtest_executable(test_block_to_ctile_map test_block_to_ctile_map.cpp) \ No newline at end of file diff --git a/test/block_to_ctile_map/test_block_to_ctile_map.cpp b/test/block_to_ctile_map/test_block_to_ctile_map.cpp new file mode 100644 index 0000000000..662d2a0fa5 --- /dev/null +++ b/test/block_to_ctile_map/test_block_to_ctile_map.cpp @@ -0,0 +1,318 @@ +#include +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "gtest/gtest.h" +#include +#include + +using namespace ck; + +static auto I0 = Number<0>{}; +static auto I1 = Number<1>{}; +static auto I2 = Number<2>{}; + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck1) +{ + const index_t M = 384; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; + const index_t N01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01, + N01); + + BlockToCTileMap_M00_N00_M01_N01 tile_map( + c_grid_desc_m_n, M01, N01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {0, 1, 1}, + {0, 2, 1}, + {0, 3, 0}, + {1, 0, 1}, + {1, 1, 1}, + {1, 2, 1}, + {1, 3, 0}, + {2, 0, 1}, + {2, 1, 1}, + {2, 2, 1}, + {2, 3, 0}, + {3, 0, 0}, + {3, 1, 0}, + {3, 2, 0}, + {3, 3, 0} + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N00_M01_N01_DeviceCTileIndexCheck0) +{ + const index_t M = 384; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + + const index_t M01 = 4; + const index_t N01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01, N01) = (%d, %d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01, + N01); + + BlockToCTileMap_M00_N00_M01_N01 + tile_map(c_grid_desc_m_n, M01, N01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == false); +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck1) +{ + const index_t M = 384; + const index_t N = 512; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + const index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 16); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 0}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 0}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 0}, + {0, 3, 1}, + {1, 3, 1}, + {2, 3, 1}, + {3, 3, 0} + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01_DeviceCTileIndexCheck0) +{ + const index_t M = 512; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + // clang-format off + std::vector> expected_m0_gridsize_validity = { + {5, 15, false}, + {4, 12, true}, + {3, 18, false}, + {2, 12, true}, + {1, 12, true} + }; + // clang-format on + + for(auto e : expected_m0_gridsize_validity) + { + const index_t M01 = std::get<0>(e); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01 tile_map( + c_grid_desc_m_n, M01); + + EXPECT_EQ(tile_map.CalculateGridSize(c_grid_desc_m_n), std::get<1>(e)); + EXPECT_EQ(tile_map.CheckValidity(c_grid_desc_m_n), std::get<2>(e)); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_M00_N0_M01Adapt tile_map( + c_grid_desc_m_n, M01); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18); + + // clang-format off + std::vector> expected_m0idx_n0idx_valid = { + {0, 0, 1}, + {1, 0, 1}, + {2, 0, 1}, + {3, 0, 1}, + {0, 1, 1}, + {1, 1, 1}, + {2, 1, 1}, + {3, 1, 1}, + {0, 2, 1}, + {1, 2, 1}, + {2, 2, 1}, + {3, 2, 1}, + {4, 0, 1}, + {5, 0, 1}, + {4, 1, 1}, + {5, 1, 1}, + {4, 2, 1}, + {5, 2, 1}, + }; + // clang-format on + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto m0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", m0, n0 = " << m0n0_idx[I0] << ", " << m0n0_idx[I1]; + std::cout << ", valid = " << tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_m0idx_n0idx_valid[i] == + std::vector{m0n0_idx[I0], + m0n0_idx[I1], + tile_map.ValidCTileIndex(m0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} + +TEST(BlockToCTileMap, TestBlockToCTileMap_KSplit_M00_N0_M01Adapt) +{ + const index_t M = 768; + const index_t N = 384; + const index_t MPerBlock = 128; + const index_t NPerBlock = 128; + const index_t MBlock = M / MPerBlock; + const index_t NBlock = N / NPerBlock; + constexpr index_t M01 = 4; + const index_t KSplit = 3; + + auto c_grid_desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + + printf("(M, N, MPerBlock, NPerBlock, M01) = (%d, %d, %d, %d, %d)\n", + M, + N, + MPerBlock, + NPerBlock, + M01); + + BlockToCTileMap_KSplit_M00_N0_M01Adapt + tile_map(c_grid_desc_m_n, M01, KSplit); + + EXPECT_TRUE(tile_map.CheckValidity(c_grid_desc_m_n) == true); + EXPECT_TRUE(tile_map.CalculateGridSize(c_grid_desc_m_n) == 18 * KSplit); + + std::vector> expected_ksplitidx_m0idx_n0idx_valid = { + {0, 0, 0, 1}, {0, 1, 0, 1}, {0, 2, 0, 1}, {0, 3, 0, 1}, {0, 0, 1, 1}, {0, 1, 1, 1}, + {0, 2, 1, 1}, {0, 3, 1, 1}, {0, 0, 2, 1}, {0, 1, 2, 1}, {0, 2, 2, 1}, {0, 3, 2, 1}, + {0, 4, 0, 1}, {0, 5, 0, 1}, {0, 4, 1, 1}, {0, 5, 1, 1}, {0, 4, 2, 1}, {0, 5, 2, 1}, + {1, 0, 0, 1}, {1, 1, 0, 1}, {1, 2, 0, 1}, {1, 3, 0, 1}, {1, 0, 1, 1}, {1, 1, 1, 1}, + {1, 2, 1, 1}, {1, 3, 1, 1}, {1, 0, 2, 1}, {1, 1, 2, 1}, {1, 2, 2, 1}, {1, 3, 2, 1}, + {1, 4, 0, 1}, {1, 5, 0, 1}, {1, 4, 1, 1}, {1, 5, 1, 1}, {1, 4, 2, 1}, {1, 5, 2, 1}, + {2, 0, 0, 1}, {2, 1, 0, 1}, {2, 2, 0, 1}, {2, 3, 0, 1}, {2, 0, 1, 1}, {2, 1, 1, 1}, + {2, 2, 1, 1}, {2, 3, 1, 1}, {2, 0, 2, 1}, {2, 1, 2, 1}, {2, 2, 2, 1}, {2, 3, 2, 1}, + {2, 4, 0, 1}, {2, 5, 0, 1}, {2, 4, 1, 1}, {2, 5, 1, 1}, {2, 4, 2, 1}, {2, 5, 2, 1}, + }; + + for(index_t i = 0; i < tile_map.CalculateGridSize(c_grid_desc_m_n); i++) + { + auto ksplitm0n0_idx = tile_map.CalculateBottomIndex(make_multi_index(i)); + std::cout << "block_1d_id = " << i << ", ksplit, m0, n0 = " << ksplitm0n0_idx[I0] << ", " + << ksplitm0n0_idx[I1] << ", " << ksplitm0n0_idx[I2]; + std::cout << ", valid = " + << tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock)) + << std::endl; + bool equal = + expected_ksplitidx_m0idx_n0idx_valid[i] == + std::vector{ksplitm0n0_idx[I0], + ksplitm0n0_idx[I1], + ksplitm0n0_idx[I2], + tile_map.ValidCTileIndex(ksplitm0n0_idx, make_tuple(MBlock, NBlock))}; + EXPECT_TRUE(equal); + } +} diff --git a/test/client_app/CMakeLists.txt b/test/client_app/CMakeLists.txt new file mode 100644 index 0000000000..f8dd8c4e0a --- /dev/null +++ b/test/client_app/CMakeLists.txt @@ -0,0 +1,11 @@ +cmake_minimum_required(VERSION 3.15) +project(ck_app) +add_compile_options(-std=c++14) + +find_package(composable_kernel 1.0.0 COMPONENTS device_operations host_tensor) +find_package(hip REQUIRED PATHS /opt/rocm) +message(STATUS "Build with HIP ${hip_VERSION}") + +add_executable(test_client_app client_app.cpp) + +target_link_libraries(test_client_app PRIVATE composable_kernel::device_operations composable_kernel::host_tensor hip::host) diff --git a/test/client_app/client_app.cpp b/test/client_app/client_app.cpp new file mode 100644 index 0000000000..665a103f70 --- /dev/null +++ b/test/client_app/client_app.cpp @@ -0,0 +1,77 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "client_app_impl.hpp" + +int main(int argc, char* argv[]) +{ + if(argc != 25) + { + printf("arg1: tensor operation (conv_fwd: ForwardConvolution)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: input tensor layout (0: NCHW; 1: NHWC)\n"); + printf("arg4: weight tensor layout (0: KCYX; 1: KYXC)\n"); + printf("arg5: output tensor layout (0: NKHW; 1: NHWK)\n"); + printf("arg6: verification (0: no; 1: yes)\n"); + printf("arg7: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg9: time kernel (0=n0, 1=yes)\n"); + printf("arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(1); + } + + const ConvDataType data_type = static_cast(std::stoi(argv[2])); + const int in_layout = static_cast(std::stoi(argv[3])); + const int wei_layout = static_cast(std::stoi(argv[4])); + const int out_layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const int init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + + const ck::index_t N = std::stoi(argv[10]); + const ck::index_t K = std::stoi(argv[11]); + const ck::index_t C = std::stoi(argv[12]); + const ck::index_t Y = std::stoi(argv[13]); + const ck::index_t X = std::stoi(argv[14]); + const ck::index_t Hi = std::stoi(argv[15]); + const ck::index_t Wi = std::stoi(argv[16]); + + const ck::index_t conv_stride_h = std::stoi(argv[17]); + const ck::index_t conv_stride_w = std::stoi(argv[18]); + const ck::index_t conv_dilation_h = std::stoi(argv[19]); + const ck::index_t conv_dilation_w = std::stoi(argv[20]); + const ck::index_t in_left_pad_h = std::stoi(argv[21]); + const ck::index_t in_left_pad_w = std::stoi(argv[22]); + const ck::index_t in_right_pad_h = std::stoi(argv[23]); + const ck::index_t in_right_pad_w = std::stoi(argv[24]); + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + ck::app::profile_conv_fwd_impl(do_verification, + init_method, + do_log, + time_kernel, + data_type, + N, + K, + C, + std::vector{Hi, Wi}, + std::vector{Y, X}, + std::vector{Ho, Wo}, + std::vector{conv_stride_h, conv_stride_w}, + std::vector{conv_dilation_h, conv_dilation_w}, + std::vector{in_left_pad_h, in_left_pad_w}, + std::vector{in_right_pad_h, in_right_pad_w}); + return 1; +} diff --git a/test/client_app/client_app_impl.hpp b/test/client_app/client_app_impl.hpp new file mode 100644 index 0000000000..f9e4145ba0 --- /dev/null +++ b/test/client_app/client_app_impl.hpp @@ -0,0 +1,214 @@ +#pragma once + +#include "host_interface.hpp" + +enum ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +enum ConvInputLayout +{ + NCHW, // 0 + NHWC, // 1 +}; + +enum ConvWeightLayout +{ + KCYX, // 0 + KYXC, // 1 +}; + +enum ConvOutputLayout +{ + NKHW, // 0 + NHWK, // 1 +}; + +void check_hip_error(void) +{ + hipError_t err = hipGetLastError(); + if(err != hipSuccess) + { + std::cerr << "Error: " << hipGetErrorString(err) << std::endl; + exit(err); + } +} +std::string getDeviceName(int device) +{ + struct hipDeviceProp_t prop; + hipGetDeviceProperties(&prop, device); + check_hip_error(); + return std::string(prop.name); +} + +int getDriver(void) +{ + int driver; + hipDriverGetVersion(&driver); + check_hip_error(); + return driver; +} + +namespace ck { +namespace app { +struct DeviceMem +{ + DeviceMem() = delete; + DeviceMem(std::size_t mem_size); + void* GetDeviceBuffer(); + void ToDevice(const void* p); + void FromDevice(void* p); + ~DeviceMem(); + + void* mpDeviceBuf; + std::size_t mMemSize; +}; + +DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) +{ + hipGetErrorString(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); +} + +void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; } + +void DeviceMem::ToDevice(const void* p) +{ + hipGetErrorString( + hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); +} + +void DeviceMem::FromDevice(void* p) +{ + hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); +} + +DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } + +void profile_conv_fwd_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + ConvDataType data_type, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) +{ + const ck::index_t Y = filter_spatial_lengths[0]; + const ck::index_t X = filter_spatial_lengths[1]; + + const ck::index_t Hi = input_spatial_lengths[0]; + const ck::index_t Wi = input_spatial_lengths[1]; + + const ck::index_t Ho = output_spatial_lengths[0]; + const ck::index_t Wo = output_spatial_lengths[1]; + + const auto in_sz = N * C * Hi * Wi; + const auto wei_sz = K * C * Y * X; + const auto out_sz = N * K * Ho * Wo; + + using WeiDataType = float; + using InDataType = float; + using OutDataType = float; + + app::DeviceMem in_device_buf(sizeof(InDataType) * in_sz); + app::DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_sz); + app::DeviceMem out_device_buf(sizeof(OutDataType) * out_sz); + // data is already on device! + + // add device Conv instances + std::vector conv_ptrs; + if(data_type == F16_F16_F16) + { + add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs); + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(conv_ptrs); + } + else if(data_type == BF16_BF16_BF16) + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(conv_ptrs); + else if(data_type == F32_F32_F32) + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(conv_ptrs); + else if(data_type == INT8_INT8_INT8) + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(conv_ptrs); + else + throw std::runtime_error("wrong! Invalid data type"); + if(conv_ptrs.empty()) + { + throw std::runtime_error("wrong! no device Conv instance found"); + } + + std::string best_conv_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + int deviceIndex = 0; + hipSetDevice(deviceIndex); + check_hip_error(); + + StreamConfig stream_config{nullptr, time_kernel}; + hipStreamCreate(&stream_config.stream_id_); + check_hip_error(); + + // profile device Conv instances + for(auto& conv_ptr : conv_ptrs) + { + auto argument_ptr = + conv_ptr.MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + auto invoker_ptr = conv_ptr.MakeInvokerPointer(); + + if(conv_ptr.IsSupportedArgument(argument_ptr.get())) + { + std::string conv_name = conv_ptr.GetTypeString(); + float ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv_name << std::endl; + + if(tflops > best_tflops) + { + best_conv_name = conv_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_conv_name << std::endl; +} + +} // namespace app +} // namespace ck diff --git a/test/conv2d_bwd_weight/CMakeLists.txt b/test/conv2d_bwd_weight/CMakeLists.txt index 7b515b6b8e..ecd5336c1f 100644 --- a/test/conv2d_bwd_weight/CMakeLists.txt +++ b/test/conv2d_bwd_weight/CMakeLists.txt @@ -4,4 +4,4 @@ include_directories(BEFORE ) add_test_executable(test_conv2d_bwd_weight conv2d_bwd_weight.cpp) -target_link_libraries(test_conv2d_bwd_weight PRIVATE host_tensor device_conv2d_bwd_weight_instance conv_fwd_util) +target_link_libraries(test_conv2d_bwd_weight PRIVATE host_tensor device_conv2d_bwd_weight_instance conv_util) diff --git a/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp index bb3ed985e3..671980f49e 100644 --- a/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp +++ b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp @@ -6,7 +6,7 @@ #include #include -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "profile_conv_bwd_weight_impl.hpp" int test_self() @@ -28,20 +28,20 @@ int test_self() ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::NHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads, + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, 2); // fp16 @@ -52,28 +52,28 @@ int test_self() ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::NHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads, + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, 2); } return pass; } int main(int argc, char* argv[]) { - int data_type = 0; - int init_method = 0; + int data_type = 1; + int init_method = 1; // Conv shape ck::index_t N = 128; @@ -155,20 +155,20 @@ int main(int argc, char* argv[]) ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::NHWK>( - 1, + true, // do_verification init_method, - 0, - 1, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads, + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, split_k); } else if(data_type == 1) @@ -180,20 +180,20 @@ int main(int argc, char* argv[]) ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::NHWK>( - 1, + true, // do_verification init_method, - 0, - 1, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads, + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_, split_k); } else diff --git a/test/conv_util/CMakeLists.txt b/test/conv_util/CMakeLists.txt index 70b3e851be..795c9ec0ac 100644 --- a/test/conv_util/CMakeLists.txt +++ b/test/conv_util/CMakeLists.txt @@ -1,2 +1,2 @@ add_gtest_executable(test_conv_util conv_util.cpp) -target_link_libraries(test_conv_util PRIVATE host_tensor conv_fwd_util) +target_link_libraries(test_conv_util PRIVATE host_tensor conv_util) diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 453225e800..98f55b872e 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -1,10 +1,10 @@ #include #include #include -#include "gtest/gtest.h" +#include #include "config.hpp" -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "tensor_layout.hpp" #include "check_err.hpp" @@ -15,13 +15,13 @@ class TestConvUtil : public ::testing::Test public: void SetNDParams(std::size_t ndims) { - conv_params.num_dim_spatial = ndims; - conv_params.filter_spatial_lengths = std::vector(ndims, 3); - conv_params.input_spatial_lengths = std::vector(ndims, 71); - conv_params.conv_filter_strides = std::vector(ndims, 2); - conv_params.conv_filter_dilations = std::vector(ndims, 1); - conv_params.input_left_pads = std::vector(ndims, 1); - conv_params.input_right_pads = std::vector(ndims, 1); + conv_params.num_dim_spatial_ = ndims; + conv_params.filter_spatial_lengths_ = std::vector(ndims, 3); + conv_params.input_spatial_lengths_ = std::vector(ndims, 71); + conv_params.conv_filter_strides_ = std::vector(ndims, 2); + conv_params.conv_filter_dilations_ = std::vector(ndims, 1); + conv_params.input_left_pads_ = std::vector(ndims, 1); + conv_params.input_right_pads_ = std::vector(ndims, 1); } protected: @@ -44,29 +44,29 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths2D) std::vector{36, 36}, "Error: ConvParams 2D default constructor.")); - conv_params.conv_filter_strides = std::vector{1, 1}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_strides_ = std::vector{1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}.")); - conv_params.conv_filter_strides = std::vector{2, 2}; - conv_params.input_left_pads = std::vector{2, 2}; - conv_params.input_right_pads = std::vector{2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_strides_ = std::vector{2, 2}; + conv_params.input_left_pads_ = std::vector{2, 2}; + conv_params.input_right_pads_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, std::vector{37, 37}, "Error: ConvParams 2D padding left/right {2,2}.")); - conv_params.conv_filter_dilations = std::vector{2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_dilations_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}.")); - conv_params.conv_filter_strides = std::vector{3, 3}; - conv_params.input_left_pads = std::vector{1, 1}; - conv_params.input_right_pads = std::vector{1, 1}; - conv_params.conv_filter_dilations = std::vector{2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_strides_ = std::vector{3, 3}; + conv_params.input_left_pads_ = std::vector{1, 1}; + conv_params.input_right_pads_ = std::vector{1, 1}; + conv_params.conv_filter_dilations_ = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE( ck::utils::check_err(out_spatial_len, std::vector{23, 23}, @@ -81,29 +81,29 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths1D) EXPECT_TRUE(ck::utils::check_err( out_spatial_len, std::vector{36}, "Error: ConvParams 1D.")); - conv_params.conv_filter_strides = std::vector{1}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_strides_ = std::vector{1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}.")); - conv_params.conv_filter_strides = std::vector{2}; - conv_params.input_left_pads = std::vector{2}; - conv_params.input_right_pads = std::vector{2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_strides_ = std::vector{2}; + conv_params.input_left_pads_ = std::vector{2}; + conv_params.input_right_pads_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, std::vector{37}, "Error: ConvParams 1D padding left/right {2}.")); - conv_params.conv_filter_dilations = std::vector{2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_dilations_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}.")); - conv_params.conv_filter_strides = std::vector{3}; - conv_params.input_left_pads = std::vector{1}; - conv_params.input_right_pads = std::vector{1}; - conv_params.conv_filter_dilations = std::vector{2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_strides_ = std::vector{3}; + conv_params.input_left_pads_ = std::vector{1}; + conv_params.input_right_pads_ = std::vector{1}; + conv_params.conv_filter_dilations_ = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE( ck::utils::check_err(out_spatial_len, std::vector{23}, @@ -118,31 +118,31 @@ TEST_F(TestConvUtil, ConvParamsGetOutputSpatialLengths3D) EXPECT_TRUE(ck::utils::check_err( out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D.")); - conv_params.conv_filter_strides = std::vector{1, 1, 1}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_strides_ = std::vector{1, 1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, std::vector{71, 71, 71}, "Error: ConvParams 3D stride {1, 1, 1}.")); - conv_params.conv_filter_strides = std::vector{2, 2, 2}; - conv_params.input_left_pads = std::vector{2, 2, 2}; - conv_params.input_right_pads = std::vector{2, 2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_strides_ = std::vector{2, 2, 2}; + conv_params.input_left_pads_ = std::vector{2, 2, 2}; + conv_params.input_right_pads_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, std::vector{37, 37, 37}, "Error: ConvParams 3D padding left/right {2, 2, 2}.")); - conv_params.conv_filter_dilations = std::vector{2, 2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err(out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D dilation {2, 2, 2}.")); - conv_params.conv_filter_strides = std::vector{3, 3, 3}; - conv_params.input_left_pads = std::vector{1, 1, 1}; - conv_params.input_right_pads = std::vector{1, 1, 1}; - conv_params.conv_filter_dilations = std::vector{2, 2, 2}; - out_spatial_len = conv_params.GetOutputSpatialLengths(); + conv_params.conv_filter_strides_ = std::vector{3, 3, 3}; + conv_params.input_left_pads_ = std::vector{1, 1, 1}; + conv_params.input_right_pads_ = std::vector{1, 1, 1}; + conv_params.conv_filter_dilations_ = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); EXPECT_TRUE(ck::utils::check_err( out_spatial_len, std::vector{23, 23, 23}, diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt index 58e6e7d3d0..55d71a41d3 100644 --- a/test/convnd_bwd_data/CMakeLists.txt +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -4,4 +4,4 @@ include_directories(BEFORE ) add_test_executable(test_convnd_bwd_data convnd_bwd_data.cpp) -target_link_libraries(test_convnd_bwd_data PRIVATE host_tensor device_convnd_bwd_data_instance conv_fwd_util) +target_link_libraries(test_convnd_bwd_data PRIVATE host_tensor device_convnd_bwd_data_instance conv_util) diff --git a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data.cpp index cbc215033b..7284680e0e 100644 --- a/test/convnd_bwd_data/convnd_bwd_data.cpp +++ b/test/convnd_bwd_data/convnd_bwd_data.cpp @@ -27,20 +27,20 @@ int main() ck::tensor_layout::convolution::NWC, ck::tensor_layout::convolution::KXC, ck::tensor_layout::convolution::NWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); pass &= ck::profiler::profile_convnd_bwd_data_impl<1, ck::half_t, @@ -50,20 +50,20 @@ int main() ck::tensor_layout::convolution::NWC, ck::tensor_layout::convolution::KXC, ck::tensor_layout::convolution::NWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); pass &= ck::profiler::profile_convnd_bwd_data_impl<1, ck::bhalf_t, @@ -73,20 +73,20 @@ int main() ck::tensor_layout::convolution::NWC, ck::tensor_layout::convolution::KXC, ck::tensor_layout::convolution::NWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); pass &= ck::profiler::profile_convnd_bwd_data_impl<1, int8_t, @@ -96,20 +96,20 @@ int main() ck::tensor_layout::convolution::NWC, ck::tensor_layout::convolution::KXC, ck::tensor_layout::convolution::NWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); } // check 2d @@ -128,20 +128,20 @@ int main() ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::NHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); pass &= ck::profiler::profile_convnd_bwd_data_impl<2, ck::half_t, @@ -151,20 +151,20 @@ int main() ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::NHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); pass &= ck::profiler::profile_convnd_bwd_data_impl<2, ck::bhalf_t, @@ -174,20 +174,20 @@ int main() ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::NHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); pass &= ck::profiler::profile_convnd_bwd_data_impl<2, int8_t, @@ -197,20 +197,20 @@ int main() ck::tensor_layout::convolution::NHWC, ck::tensor_layout::convolution::KYXC, ck::tensor_layout::convolution::NHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); } // check 3d @@ -232,20 +232,20 @@ int main() ck::tensor_layout::convolution::NDHWC, ck::tensor_layout::convolution::KZYXC, ck::tensor_layout::convolution::NDHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); pass &= ck::profiler::profile_convnd_bwd_data_impl<3, ck::half_t, @@ -255,20 +255,20 @@ int main() ck::tensor_layout::convolution::NDHWC, ck::tensor_layout::convolution::KZYXC, ck::tensor_layout::convolution::NDHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); pass &= ck::profiler::profile_convnd_bwd_data_impl<3, ck::bhalf_t, @@ -278,20 +278,20 @@ int main() ck::tensor_layout::convolution::NDHWC, ck::tensor_layout::convolution::KZYXC, ck::tensor_layout::convolution::NDHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); pass &= ck::profiler::profile_convnd_bwd_data_impl<3, int8_t, @@ -301,20 +301,20 @@ int main() ck::tensor_layout::convolution::NDHWC, ck::tensor_layout::convolution::KZYXC, ck::tensor_layout::convolution::NDHWK>( - 1, // do_verification, - 1, // init_method, - 0, // do_log, - 1, // nrepeat, - param.N, - param.K, - param.C, - param.input_spatial_lengths, - param.filter_spatial_lengths, + true, // do_verification + 1, // init_method + false, // do_log + false, // time_kernel + param.N_, + param.K_, + param.C_, + param.input_spatial_lengths_, + param.filter_spatial_lengths_, param.GetOutputSpatialLengths(), - param.conv_filter_strides, - param.conv_filter_dilations, - param.input_left_pads, - param.input_right_pads); + param.conv_filter_strides_, + param.conv_filter_dilations_, + param.input_left_pads_, + param.input_right_pads_); } if(pass) diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 1d2ae3e4e3..444ec6c8aa 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -1,13 +1,13 @@ add_custom_target(test_convnd_fwd) add_gtest_executable(test_conv1d_fwd conv1d_fwd.cpp) -target_link_libraries(test_conv1d_fwd PRIVATE host_tensor device_conv1d_fwd_instance conv_fwd_util) +target_link_libraries(test_conv1d_fwd PRIVATE host_tensor device_conv1d_fwd_instance conv_util) add_dependencies(test_convnd_fwd test_conv1d_fwd) add_gtest_executable(test_conv2d_fwd conv2d_fwd.cpp) -target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance conv_fwd_util) +target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance device_convnd_2d_fwd_instance conv_util) add_dependencies(test_convnd_fwd test_conv2d_fwd) add_gtest_executable(test_conv3d_fwd conv3d_fwd.cpp) -target_link_libraries(test_conv3d_fwd PRIVATE host_tensor device_conv3d_fwd_instance conv_fwd_util) +target_link_libraries(test_conv3d_fwd PRIVATE host_tensor device_conv3d_fwd_instance conv_util) add_dependencies(test_convnd_fwd test_conv3d_fwd) diff --git a/test/convnd_fwd/conv1d_fwd.cpp b/test/convnd_fwd/conv1d_fwd.cpp index c161b2795e..9b4708e94b 100644 --- a/test/convnd_fwd/conv1d_fwd.cpp +++ b/test/convnd_fwd/conv1d_fwd.cpp @@ -1,93 +1,189 @@ #include -#include #include #include #include "gtest/gtest.h" #include "data_type.hpp" #include "element_wise_operation.hpp" -#include "conv_fwd_util.hpp" +#include "library/include/ck/library/utility/conv_util.hpp" #include "conv_util.hpp" namespace { -template -bool test_conv1d_nwc_instances(const std::vector& conv_ptrs) +class Conv1dFwdNWCInstances : public ::testing::Test +{ + public: + template + bool test_conv1d_nwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), params_default_); + } + + template + bool test_filter1x1_stride1_pad0() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), + params_filter1x1_stride1_pad0_); + } + + template + bool test_filter1x1_pad0() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), + params_filter1x1_pad0_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 1, 4, 256, 64, {3}, {71}, {2}, {2}, {2}, {2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 1, 4, 256, 64, {1}, {28}, {1}, {1}, {0}, {0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 1, 4, 256, 64, {1}, {28}, {2}, {1}, {0}, {0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv1DFwdNWC, IntegerValues) { using namespace std::placeholders; using namespace ck::utils; namespace ctl = ck::tensor_layout::convolution; + using T = float; - ck::utils::conv::ConvParams params; - params.num_dim_spatial = 1; - params.filter_spatial_lengths = std::vector{3}; - params.input_spatial_lengths = std::vector{71}; - params.conv_filter_strides = std::vector{2}; - params.conv_filter_dilations = std::vector{1}; - params.input_left_pads = std::vector{1}; - params.input_right_pads = std::vector{1}; + ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; - conv::ConvFwdOpInstance conv_instance(params); + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<1, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); auto reference_conv_fwd_fun = std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return run_engine.Test(conv_ptrs); -} - -} // anonymous namespace - -TEST(Conv1DFwdNWC, TestConv1D) -{ - using namespace std::placeholders; - using namespace ck::utils; - namespace ctl = ck::tensor_layout::convolution; - - ck::utils::conv::ConvParams params; - params.num_dim_spatial = 1; - params.N = 2; - params.K = 16; - params.C = 4; - params.filter_spatial_lengths = std::vector{3}; - params.input_spatial_lengths = std::vector{16}; - params.conv_filter_strides = std::vector{1}; - params.conv_filter_dilations = std::vector{1}; - params.input_left_pads = std::vector{1}; - params.input_right_pads = std::vector{1}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<1>(conv_ptrs); - conv::ConvFwdOpInstance conv_instance( - params); - - auto reference_conv_fwd_fun = std::bind( - conv::run_reference_convolution_forward<1, float, float, float>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); run_engine.SetAtol(1e-5); run_engine.SetRtol(1e-4); EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -TEST(Conv1DFwdNWC, Bf16Iinstances) +TEST(Conv1DFwdNWC, FloatingPointValues) { - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<1, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(0.1); + run_engine.SetRtol(1e-2); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -TEST(Conv1DFwdNWC, F16Instances) +TEST_F(Conv1dFwdNWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv1DFwdNWC, F32Instances) +TEST_F(Conv1dFwdNWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv1DFwdNWC, Int8Instances) +TEST_F(Conv1dFwdNWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv1dFwdNWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp index e3815f778a..4e0238cc4f 100644 --- a/test/convnd_fwd/conv2d_fwd.cpp +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -1,91 +1,265 @@ -#include -#include #include #include #include "gtest/gtest.h" +#include "ck/library/utility/conv_util.hpp" +#include "config.hpp" +#include "conv_util.hpp" #include "data_type.hpp" #include "element_wise_operation.hpp" -#include "conv_fwd_util.hpp" -#include "conv_util.hpp" +#include "fill.hpp" namespace { -template -bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs) +class Conv2dFwdNHWCInstances : public ::testing::Test +{ + public: + template + bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), params_default_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_default_); + } + } + + template + bool test_filter1x1_stride1_pad0(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), + params_filter1x1_stride1_pad0_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_stride1_pad0_); + } + } + + template + bool test_filter1x1_pad0(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), params_filter1x1_pad0_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_pad0_); + } + } + + template + bool test_oddC() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), params_oddC_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 2, 4, 256, 64, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 2, 4, 256, 64, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_oddC_{ + 2, 4, 256, 3, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv2DFwdNHWC, IntegerValues) { using namespace std::placeholders; using namespace ck::utils; + using T = float; - conv::ConvParams params; - params.num_dim_spatial = 2; - params.filter_spatial_lengths = std::vector{3, 3}; - params.input_spatial_lengths = std::vector{71, 71}; - params.conv_filter_strides = std::vector{2, 2}; - params.conv_filter_dilations = std::vector{1, 1}; - params.input_left_pads = std::vector{1, 1}; - params.input_right_pads = std::vector{1, 1}; + ck::utils::conv::ConvParams params{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {1, 1}, {2, 2}, {2, 2}, {2, 2}}; - conv::ConvFwdOpInstance conv_instance(params); + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<2, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); auto reference_conv_fwd_fun = std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return run_engine.Test(conv_ptrs); -} - -} // anonymous namespace - -TEST(Conv2DFwdNHWC, TestConv2D) -{ - using namespace std::placeholders; - using namespace ck::utils; - - ck::utils::conv::ConvParams params; - params.N = 2; - params.K = 16; - params.C = 4; - params.input_spatial_lengths = std::vector{16, 16}; - params.conv_filter_strides = std::vector{1, 1}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<2>(conv_ptrs); - conv::ConvFwdOpInstance conv_instance(params); - - auto reference_conv_fwd_fun = std::bind( - conv::run_reference_convolution_forward<2, float, float, float>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); run_engine.SetAtol(1e-5); run_engine.SetRtol(1e-4); EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -TEST(Conv2DFwdNHWC, Bf16Instances) +TEST(Conv2DFwdNHWC, FloatingPointValues) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + using namespace std::placeholders; + using namespace ck::utils; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<2, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(2e-4); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -TEST(Conv2DFwdNHWC, F16Instances) +TEST_F(Conv2dFwdNHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_oddC) { EXPECT_TRUE(this->test_oddC()); } +TEST_F(Conv2dFwdNHWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv2DFwdNHWC, BF32Instances) +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_default) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_default(true)); } - -TEST(Conv2DFwdNHWC, F32Instances) +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); } - -TEST(Conv2DFwdNHWC, Int8Instances) +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_pad0) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_default) +{ + EXPECT_TRUE(this->test_default(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F32_default) { EXPECT_TRUE(this->test_default(true)); } +TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_I8_default) { EXPECT_TRUE(this->test_default(true)); } +TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); } diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp index fc3da3e9c7..2470727fd7 100644 --- a/test/convnd_fwd/conv3d_fwd.cpp +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -7,66 +7,148 @@ #include "data_type.hpp" #include "element_wise_operation.hpp" -#include "conv_fwd_util.hpp" +#include "library/include/ck/library/utility/conv_util.hpp" #include "conv_util.hpp" namespace { -template -bool test_conv3d_ndhwc_instances(const std::vector& conv_ptrs) +class Conv3dFwdNDHWCInstances : public ::testing::Test +{ + public: + template + bool test_conv3d_nwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), params_default_); + } + + template + bool test_filter1x1_stride1_pad0() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), + params_filter1x1_stride1_pad0_); + } + + template + bool test_filter1x1_pad0() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), + params_filter1x1_pad0_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 3, 4, 256, 64, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv3DFwdNDHWC, IntegerValues) { using namespace std::placeholders; using namespace ck::utils; namespace ctl = ck::tensor_layout::convolution; + using T = float; - conv::ConvParams params; - params.N = 64; - params.num_dim_spatial = 3; - params.filter_spatial_lengths = std::vector{3, 3, 2}; - params.input_spatial_lengths = std::vector{32, 32, 2}; - params.conv_filter_strides = std::vector{2, 2, 2}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; + ck::utils::conv::ConvParams params{ + 3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; - conv::ConvFwdOpInstance conv_instance(params); + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); auto reference_conv_fwd_fun = std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return run_engine.Test(conv_ptrs); + run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -} // anonymous namespace - -TEST(Conv3DFwdNDHWC, TestConv3D) +TEST(Conv3DFwdNDHWC, FloatingPointValues) { using namespace std::placeholders; using namespace ck::utils; namespace ctl = ck::tensor_layout::convolution; + using T = ck::half_t; - conv::ConvParams params; - params.num_dim_spatial = 3; - params.N = 2; - params.K = 16; - params.C = 4; - params.filter_spatial_lengths = std::vector{3, 3, 3}; - params.input_spatial_lengths = std::vector{16, 16, 16}; - params.conv_filter_strides = std::vector{1, 1, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; + ck::utils::conv::ConvParams params{ + 3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - conv::ConvFwdOpInstance conv_instance( - params); + test::conv::get_test_convolution_fwd_instance<3, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); - auto reference_conv_fwd_fun = std::bind( - conv::run_reference_convolution_forward<3, float, float, float>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - run_engine.SetAtol(1e-5); - run_engine.SetRtol(1e-4); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-3); + run_engine.SetRtol(1e-3); EXPECT_TRUE(run_engine.Test(conv_ptrs)); } @@ -74,36 +156,36 @@ TEST(Conv3DFwdNDHWC, InputOver2GB) { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using namespace ck::utils; + using T = float; // >2GB Input conv::ConvParams params; - params.num_dim_spatial = 3; - params.N = 2; - params.K = 16; - params.C = 32; - params.filter_spatial_lengths = std::vector{3, 3, 3}; - params.input_spatial_lengths = std::vector{32, 1000, 1000}; - params.conv_filter_strides = std::vector{1, 1, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 32; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{32, 1000, 1000}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, nullptr, nullptr, - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, params.GetOutputSpatialLengths(), - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, PassThrough{}, PassThrough{}, PassThrough{}); @@ -114,36 +196,36 @@ TEST(Conv3DFwdNDHWC, FiltersOver2GB) { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using namespace ck::utils; + using T = float; // >2GB Filters conv::ConvParams params; - params.num_dim_spatial = 3; - params.N = 2; - params.K = 16; - params.C = 32; - params.filter_spatial_lengths = std::vector{4, 1000, 1000}; - params.input_spatial_lengths = std::vector{16, 16, 16}; - params.conv_filter_strides = std::vector{1, 1, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 32; + params.filter_spatial_lengths_ = std::vector{4, 1000, 1000}; + params.input_spatial_lengths_ = std::vector{16, 16, 16}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{1, 1, 1}; + params.input_right_pads_ = std::vector{1, 1, 1}; std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, nullptr, nullptr, - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, params.GetOutputSpatialLengths(), - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, PassThrough{}, PassThrough{}, PassThrough{}); @@ -154,61 +236,78 @@ TEST(Conv3DFwdNDHWC, OutputOver2GB) { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using namespace ck::utils; + using T = float; // >2GB Output conv::ConvParams params; - params.num_dim_spatial = 3; - params.N = 2; - params.K = 16; - params.C = 2; - params.filter_spatial_lengths = std::vector{1, 1, 1}; - params.input_spatial_lengths = std::vector{1000, 1000, 30}; - params.conv_filter_strides = std::vector{1, 1, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{2, 2, 2}; - params.input_right_pads = std::vector{2, 2, 2}; + params.num_dim_spatial_ = 3; + params.N_ = 2; + params.K_ = 16; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{1, 1, 1}; + params.input_spatial_lengths_ = std::vector{1000, 1000, 30}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{2, 2, 2}; + params.input_right_pads_ = std::vector{2, 2, 2}; std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, nullptr, nullptr, - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, params.GetOutputSpatialLengths(), - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, PassThrough{}, PassThrough{}, PassThrough{}); EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); } -TEST(Conv3DFwdNDHWC, Bf16Instances) +TEST_F(Conv3dFwdNDHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv3DFwdNDHWC, F16Instances) +TEST_F(Conv3dFwdNDHWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv3DFwdNDHWC, F32Instances) +TEST_F(Conv3dFwdNDHWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv3DFwdNDHWC, Int8Instances) +TEST_F(Conv3dFwdNDHWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } diff --git a/test/convnd_fwd/conv_util.hpp b/test/convnd_fwd/conv_util.hpp index 4f77101563..1ec83bd118 100644 --- a/test/convnd_fwd/conv_util.hpp +++ b/test/convnd_fwd/conv_util.hpp @@ -1,15 +1,33 @@ -#ifndef TEST_CONV_UTIL_HPP -#define TEST_CONV_UTIL_HPP +#pragma once #include #include "config.hpp" -#include "conv_fwd_util.hpp" +#include "data_type.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "host_tensor.hpp" #include "sequence.hpp" +namespace ck { +namespace tensor_operation { +namespace device { + +using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr; +namespace device_conv2d_fwd_instance { + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + namespace test { namespace conv { @@ -26,57 +44,128 @@ using DeviceConvFwdNoOpPtr = static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -template +template using DeviceConvNDFwdInstance = ck::tensor_operation::device:: DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< // clang-format off InDataType, // WeiDataType, // OutDataType, // - InDataType, // + AccDataType, // Accumulator data type. InElementOp, // Input Elementwise Operation WeiElementOp, // Weights Elementwise Operation OutElementOp, // Output Elementwise Operation ConvFwdDefault, // ConvForwardSpecialization SpatialDims, // SptialDims - 64, // BlockSize - 16, // MPerBlock - 16, // NPerBlock + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock 4, // K0PerBlock - 1, // K1 - 16, // MPerXDL - 16, // NPerXDL - 1, // MXdlPerWave - 1, // NXdlPerWave - S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim - 1, // ABlockTransferSrcScalarPerVector - 1, // ABlockTransferDstScalarPerVector_K1 + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 true, // ABlockLdsAddExtraM - S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim - 1, // BBlockTransferSrcScalarPerVector - 1, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockTransferAddExtraN + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN 7, // CThreadTransferSrcDstVectorDim - 1>; // CThreadTransferDstScalarPerVector + 1>; // CThreadTransferDstScalarPerVector // clang-format on template + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType> void get_test_convolution_fwd_instance(std::vector& instances) { - using ConvInstanceT = DeviceConvNDFwdInstance; + using ConvInstanceT = + DeviceConvNDFwdInstance; instances.emplace_back(std::make_unique()); } +// TODO (aosewski) +// Temporary solution to get all DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K +// instances. When switched over to DeviceConvNDFwdXdl for 2D remove ConvolutionNDFwdInstances +// structures. +template +struct ConvolutionNDFwdInstances; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + } // namespace conv } // namespace test - -#endif diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt index 83b3c1e2e3..b8679e3715 100644 --- a/test/gemm/CMakeLists.txt +++ b/test/gemm/CMakeLists.txt @@ -1,15 +1,29 @@ -add_test_executable(test_gemm_fp32 gemm_fp32.cpp) -target_link_libraries(test_gemm_fp32 PRIVATE host_tensor) -target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) +# GEMM XDL +add_test_executable(test_gemm_xdl_fp32 gemm_xdl_fp32.cpp) +target_link_libraries(test_gemm_xdl_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_fp32 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_fp16 gemm_fp16.cpp) -target_link_libraries(test_gemm_fp16 PRIVATE host_tensor) -target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +target_link_libraries(test_gemm_xdl_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_fp16 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_bf16 gemm_bf16.cpp) -target_link_libraries(test_gemm_bf16 PRIVATE host_tensor) -target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +target_link_libraries(test_gemm_xdl_bf16 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_bf16 PRIVATE device_gemm_instance) -add_test_executable(test_gemm_int8 gemm_int8.cpp) -target_link_libraries(test_gemm_int8 PRIVATE host_tensor) -target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) +add_test_executable(test_gemm_xdl_int8 gemm_xdl_int8.cpp) +target_link_libraries(test_gemm_xdl_int8 PRIVATE host_tensor) +target_link_libraries(test_gemm_xdl_int8 PRIVATE device_gemm_instance) + +# GEMM DL +add_test_executable(test_gemm_dl_fp32 gemm_dl_fp32.cpp) +target_link_libraries(test_gemm_dl_fp32 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp32 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_dl_fp16 gemm_dl_fp16.cpp) +target_link_libraries(test_gemm_dl_fp16 PRIVATE host_tensor) +target_link_libraries(test_gemm_dl_fp16 PRIVATE device_gemm_instance) + +add_test_executable(test_gemm_dl_int8 gemm_dl_int8.cpp) +target_link_libraries(test_gemm_dl_int8 PRIVATE host_tensor) +TArget_link_libraries(test_gemm_dl_int8 PRIVATE device_gemm_instance) diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_dl_fp16.cpp similarity index 78% rename from test/gemm/gemm_int8.cpp rename to test/gemm/gemm_dl_fp16.cpp index 870881dd76..8a539372ba 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_dl_fp16.cpp @@ -1,132 +1,135 @@ -#include -#include -#include -#include -#include -#include -#include - -#include "gemm_util.hpp" -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_cshuffle.hpp" -#include "element_wise_operation.hpp" -#include "reference_gemm.hpp" -#include "gemm_specialization.hpp" - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using DeviceGemmNoOpPtr = - ck::tensor_operation::device::DeviceGemmPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_gemm_instance { -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( - std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( - std::vector&); -} // namespace device_gemm_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -int main() -{ - using ADataType = int8_t; - using BDataType = int8_t; - using CDataType = int8_t; - - using RowMajor = ck::tensor_layout::gemm::RowMajor; - using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; - - std::vector gemmPtrs; - bool res = true; - - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - gemmPtrs.clear(); - ck::tensor_operation::device::device_gemm_instance:: - add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(gemmPtrs); - - for(auto& gemmPtr : gemmPtrs) - { - res &= ck::gemm_util::TestGemm{}(gemmPtr); - } - - std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - return res ? 0 : 1; -} +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = ck::half_t; + using BDataType = ck::half_t; + using CDataType = ck::half_t; + using AccDataType = float; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + + std::vector gemmPtrs; + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_dl_fp32.cpp b/test/gemm/gemm_dl_fp32.cpp new file mode 100644 index 0000000000..3484458042 --- /dev/null +++ b/test/gemm/gemm_dl_fp32.cpp @@ -0,0 +1,133 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = float; + using BDataType = float; + using CDataType = float; + using AccDataType = float; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_dl_int8.cpp b/test/gemm/gemm_dl_int8.cpp new file mode 100644 index 0000000000..5dfb7221cb --- /dev/null +++ b/test/gemm/gemm_dl_int8.cpp @@ -0,0 +1,133 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "../gemm/gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_dl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int8_t; + using AccDataType = int; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index 17e954b7f2..a3cafa6df1 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -60,7 +60,7 @@ template -void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, +bool RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, const ck::gemm_util::GemmParams& params, const Tensor& A, const Tensor& B, @@ -73,9 +73,6 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, DeviceMem b_k_n_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpace()); - a_m_k_device_buf.ToDevice(A.mData.data()); - b_k_n_device_buf.ToDevice(B.mData.data()); - auto invoker_ptr = gemmPtr->MakeInvokerPointer(); auto argument_ptr = gemmPtr->MakeArgumentPointer(static_cast(a_m_k_device_buf.GetDeviceBuffer()), @@ -91,21 +88,30 @@ void RunDeviceGEMM(DeviceGemmPtr_& gemmPtr, b_element_op, c_element_op); - if(!gemmPtr->IsSupportedArgument(argument_ptr.get())) + if(gemmPtr->IsSupportedArgument(argument_ptr.get())) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } + a_m_k_device_buf.ToDevice(A.mData.data()); + b_k_n_device_buf.ToDevice(B.mData.data()); + invoker_ptr->Run(argument_ptr.get()); + c_m_n_device_buf.FromDevice(C.mData.data()); - invoker_ptr->Run(argument_ptr.get()); - c_m_n_device_buf.FromDevice(C.mData.data()); + return true; + } + else + { + std::cout << "device_gemm with the specified compilation parameters does " + "not support this GEMM problem" + << std::endl; + + return false; + } } template ; @@ -188,28 +195,40 @@ struct TestGemm a, b, c_host, a_element_op, b_element_op, c_element_op); // Act - ck::gemm_util::RunDeviceGEMM( + bool is_supported = ck::gemm_util::RunDeviceGEMM( gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); - // Assert - bool res = false; - if(std::is_same::value) + if(is_supported) { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } - else if(std::is_same::value) - { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } - else if(std::is_same::value) - { - res = ck::utils::check_err(c_device.mData, c_host.mData); - std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; - } + // Assert + bool res = false; + if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } + else if(std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; + } - return res; + return res; + } + else + { + return true; + } } }; @@ -299,6 +318,7 @@ struct TestGemmBF16 // use fp32 host kernel to verify bf16 device kernel using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm +#include +#include +#include +#include +#include +#include + +#include "gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +inline std::string get_device_name() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return std::string(); + } + + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return std::string(); + } + const std::string name(props.gcnArchName); + + return name; +} + +int main() +{ + if(get_device_name().find("gfx90a") == std::string::npos) + { + std::cout << "TestGemm ..... SUCCESS" << std::endl; + return 0; + } + using ADataType = double; + using BDataType = double; + using CDataType = double; + using AccDataType = double; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + bool res = true; + std::vector gemmPtrs; + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm/gemm_xdl_int8.cpp b/test/gemm/gemm_xdl_int8.cpp new file mode 100644 index 0000000000..0075b79cf7 --- /dev/null +++ b/test/gemm/gemm_xdl_int8.cpp @@ -0,0 +1,133 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "gemm_util.hpp" +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdl.hpp" +#include "device_gemm_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using DeviceGemmNoOpPtr = + ck::tensor_operation::device::DeviceGemmPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(std::vector&); +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +int main() +{ + using ADataType = int8_t; + using BDataType = int8_t; + using CDataType = int8_t; + using AccDataType = int32_t; + + using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ColumnMajor = ck::tensor_layout::gemm::ColumnMajor; + + std::vector gemmPtrs; + bool res = true; + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + gemmPtrs.clear(); + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(gemmPtrs); + + for(auto& gemmPtr : gemmPtrs) + { + res &= ck::gemm_util::TestGemm{}(gemmPtr); + } + + std::cout << "TestGemm ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return res ? 0 : 1; +} diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp index 8deb66b2b0..6c7bb9658f 100644 --- a/test/gemm_reduce/gemm_reduce_fp16.cpp +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -16,22 +16,22 @@ int main() pass = pass && ck::profiler:: profile_gemm_reduce_impl( - true, 1, false, 1, M, N, K, K, N, N); + true, 1, false, false, M, N, K, K, N, N); pass = pass && ck::profiler:: profile_gemm_reduce_impl( - true, 1, false, 1, M, N, K, K, K, N); + true, 1, false, false, M, N, K, K, K, N); pass = pass && ck::profiler:: profile_gemm_reduce_impl( - true, 1, false, 1, M, N, K, M, N, N); + true, 1, false, false, M, N, K, M, N, N); pass = pass && ck::profiler:: profile_gemm_reduce_impl( - true, 1, false, 1, M, N, K, M, K, N); + true, 1, false, false, M, N, K, M, K, N); if(pass) { diff --git a/test/gemm_split_k/gemm_split_k.cpp b/test/gemm_split_k/gemm_split_k.cpp index a3d4f9b2ec..b63361aa1b 100644 --- a/test/gemm_split_k/gemm_split_k.cpp +++ b/test/gemm_split_k/gemm_split_k.cpp @@ -45,7 +45,7 @@ static bool check_out(const Tensor& ref, const Tensor& result) { float max_diff = 1e-6; - for(int i = 0; i < ref.mData.size(); ++i) + for(std::size_t i = 0; i < ref.mData.size(); ++i) { float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); if(max_diff < diff) @@ -187,9 +187,10 @@ int test_gemm(const gemmArgs& args) if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) { - invoker_ptr->Run(argument_ptr.get(), 0); + invoker_ptr->Run(argument_ptr.get()); c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + if(!check_out(c_m_n_host_result, c_m_n_device_result)) { success = false; diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp index 2260b01462..fc8ec66b51 100644 --- a/test/grouped_gemm/grouped_gemm_fp16.cpp +++ b/test/grouped_gemm/grouped_gemm_fp16.cpp @@ -104,7 +104,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) b_tensors_device.reserve(group_count); c_tensors_device.reserve(group_count); - for(int i = 0; i < gemm_shapes.size(); i++) + for(std::size_t i = 0; i < gemm_shapes.size(); i++) { a_tensors.emplace_back(Tensor(f_host_tensor_descriptor( gemm_shapes[i].M, gemm_shapes[i].K, gemm_shapes[i].StrideA, ALayout{}))); @@ -119,7 +119,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); } - for(int i = 0; i < gemm_shapes.size(); i++) + for(std::size_t i = 0; i < gemm_shapes.size(); i++) { a_tensors_device.emplace_back( std::make_unique(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize())); @@ -141,18 +141,28 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) auto c_element_op = PassThrough{}; // do GEMM - auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); + auto invoker_ptr = groupedGemmPtr->MakeInvokerPointer(); + auto argument_ptr = groupedGemmPtr->MakeArgumentPointer( p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op); + DeviceMem gemm_desc_workspace(groupedGemmPtr->GetWorkSpaceSize(argument_ptr.get())); + + groupedGemmPtr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + invoker_ptr->Run(argument_ptr.get()); - for(int i = 0; i < gemm_shapes.size(); i++) + for(std::size_t i = 0; i < gemm_shapes.size(); i++) { c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data()); - using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); diff --git a/test/reduce/reduce_no_index.cpp b/test/reduce/reduce_no_index.cpp index 28370cb2cd..20030392b5 100644 --- a/test/reduce/reduce_no_index.cpp +++ b/test/reduce/reduce_no_index.cpp @@ -1,384 +1,10 @@ #include "getopt.h" -#include "check_err.hpp" -#include "device_reduce_instance.hpp" -#include "reduction_enums.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_reduction.hpp" -#include "reduce_util.hpp" +#include "host_common_util.hpp" +#include "profile_reduce_impl.hpp" using namespace ck; -namespace { - -template -static inline std::vector get_invariant_dims(const std::vector& reduceDims) -{ - assert(NumReduceDim == reduceDims.size()); - - int reduceFlag = 0; - - // flag the bits for the reduceDims - for(int i = 0; i < NumReduceDim; i++) - { - reduceFlag |= 1 << reduceDims[i]; - }; - - std::vector invariantDims; - - // collect invariant dimensions - for(int i = 0; i < Rank; i++) - if((reduceFlag & (1 << i)) == 0) - { - invariantDims.push_back(i); - }; - - return invariantDims; -}; - -constexpr int Rank = 4; - -constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::AVG; -constexpr NanPropagation NanOpt = NanPropagation::PROPAGATE_NAN; -constexpr bool PropagateNan = false; -constexpr ReduceTensorIndices IndicesOpt = ReduceTensorIndices::NO_INDICES; -constexpr bool NeedIndices = false; - -template -bool test_reduce_no_index_impl(int init_method, - const std::vector& inLengths, - const std::vector& reduceDims, - float alpha, - float beta) -{ - using namespace ck::tensor_operation::device; - using namespace ck::tensor_operation::device::device_reduce_instance; - using namespace ck::host_reduce; - - constexpr bool out_support_atomic_add = std::is_same::value; - constexpr bool op_support_atomic_add = true; - constexpr bool use_atomic_add = (out_support_atomic_add && op_support_atomic_add); - - Tensor in(inLengths); - - std::vector outLengths; - - const auto invariantDims = get_invariant_dims(reduceDims); - - if(reduceDims.size() == Rank) - outLengths.push_back(1); - else - for(auto dim : invariantDims) - outLengths.push_back(inLengths[dim]); - - Tensor out_ref(outLengths); - Tensor out(outLengths); - - // only used when the OutDataType is bhalf_t - Tensor out_ref_fp32(outLengths); - Tensor out_fp32(outLengths); - - auto inStrides = in.mDesc.GetStrides(); - auto outStrides = out.mDesc.GetStrides(); - - size_t invariant_total_length = out.mDesc.GetElementSize(); - size_t reduce_total_length = in.mDesc.GetElementSize() / invariant_total_length; - - std::size_t num_thread = 1; - - switch(init_method) - { - case 0: break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); - if(beta != 0.0f) - out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - if(beta != 0.0f) - out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - default: - in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); - if(beta != 0.0f) - out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); - } - - if(beta != 0.0f) - for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) - out.mData[i] = out_ref.mData[i]; - - // these buffers are usually provided by the user application - DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); - DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); - - in_dev.ToDevice(in.mData.data()); - - if(beta != 0.0f) - out_dev.ToDevice(out.mData.data()); - - using InElementwiseOperation_0 = - typename reduce_unary_operator::InElementwiseOperation; - using AccElementwiseOperation_0 = - typename reduce_unary_operator:: - AccElementwiseOperation; - using InElementwiseOperation_1 = - typename reduce_unary_operator:: - InElementwiseOperation; - using AccElementwiseOperation_1 = - typename reduce_unary_operator:: - AccElementwiseOperation; - using InElementwiseOperation_2 = - typename reduce_unary_operator:: - InElementwiseOperation; - using AccElementwiseOperation_2 = - typename reduce_unary_operator:: - AccElementwiseOperation; - - using DeviceReduceInstPtr0 = - DeviceReducePtr; - using DeviceReduceInstPtr1 = - DeviceReducePtr; - using DeviceReduceInstPtr2 = - DeviceReducePtr; - - std::vector reduce0_ptrs; - std::vector reduce1_ptrs; - std::vector reduce2_ptrs; - - add_device_reduce_instance_threadwise(reduce0_ptrs); - - add_device_reduce_instance_blockwise(reduce0_ptrs); - - if constexpr(use_atomic_add) - { - add_device_reduce_instance_multiblock_atomic_add(reduce0_ptrs); - } - else - { - add_device_reduce_instance_multiblock_partial_reduce(reduce1_ptrs); - }; - - // used for secondary reduction - if constexpr(!use_atomic_add) - { - add_device_reduce_instance_blockwise_second_call(reduce2_ptrs); - }; - - if(reduce0_ptrs.empty() && reduce1_ptrs.empty()) - { - throw std::runtime_error("Wrong! No device REDUCE instance found"); - }; - - bool result = true; - - ReductionHost - hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); - - hostReduce.Run(alpha, in.mData.data(), beta, out_ref.mData.data(), nullptr); - - const auto i_inLengths = to_int_vector(inLengths); - const auto i_inStrides = to_int_vector(inStrides); - const auto i_outLengths = to_int_vector(outLengths); - const auto i_outStrides = to_int_vector(outStrides); - - for(auto& reduce_ptr : reduce0_ptrs) - { - auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths, reduceDims); - - DeviceMem ws_dev(wsSizeInBytes); - - InElementwiseOperation_0 in_elementwise_op_0(static_cast(reduce_total_length)); - AccElementwiseOperation_0 acc_elementwise_op_0(static_cast(reduce_total_length)); - - auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, - i_inStrides, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - in_dev.GetDeviceBuffer(), - out_dev.GetDeviceBuffer(), - nullptr, - ws_dev.GetDeviceBuffer(), - in_elementwise_op_0, - acc_elementwise_op_0); - - if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) - continue; - - auto invoker_ptr = reduce_ptr->MakeInvokerPointer(); - - (void)invoker_ptr->Run(argument_ptr.get()); - - out_dev.FromDevice(out.mData.data()); - - bool single_result = true; - - if constexpr(std::is_same::value || - std::is_same::value) - { - reduce_util::to_f32_vector(out, out_fp32); - reduce_util::to_f32_vector(out_ref, out_ref_fp32); - single_result = ck::utils::check_err( - out_fp32.mData, out_ref_fp32.mData, "Error: incorrect data result!"); - } - else - { - single_result = - ck::utils::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); - }; - - if(!single_result) - { - std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl; - result = false; - } - }; - - for(auto& reduce_ptr : reduce1_ptrs) - { - auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths, reduceDims); - - DeviceMem ws_dev(wsSizeInBytes); - - InElementwiseOperation_1 in_elementwise_op_1(static_cast(reduce_total_length)); - AccElementwiseOperation_1 acc_elementwise_op_1(static_cast(reduce_total_length)); - - auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, - i_inStrides, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - in_dev.GetDeviceBuffer(), - out_dev.GetDeviceBuffer(), - nullptr, - ws_dev.GetDeviceBuffer(), - in_elementwise_op_1, - acc_elementwise_op_1); - - if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) - continue; - - auto invoker_ptr = reduce_ptr->MakeInvokerPointer(); - - (void)invoker_ptr->Run(argument_ptr.get()); - - std::vector inLengths2 = reduce_ptr->GetWorkspace2dLengths(argument_ptr.get()); - std::vector inStrides2{inLengths2[1], 1}; - - for(auto& reduce2_ptr : reduce2_ptrs) - { - InElementwiseOperation_2 in_elementwise_op_2(static_cast(reduce_total_length)); - AccElementwiseOperation_2 acc_elementwise_op_2( - static_cast(reduce_total_length)); - - auto argument2_ptr = reduce2_ptr->MakeArgumentPointer(inLengths2, - inStrides2, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - ws_dev.GetDeviceBuffer(), - out_dev.GetDeviceBuffer(), - nullptr, - ws_dev.GetDeviceBuffer(), - in_elementwise_op_2, - acc_elementwise_op_2); - - if(!reduce2_ptr->IsSupportedArgument(argument2_ptr.get())) - continue; - - std::string reduce2_name = reduce2_ptr->GetTypeString(); - - auto invoker2_ptr = reduce2_ptr->MakeInvokerPointer(); - - (void)invoker2_ptr->Run(argument2_ptr.get()); - - out_dev.FromDevice(out.mData.data()); - - bool single_result = true; - - if constexpr(std::is_same::value || - std::is_same::value) - { - reduce_util::to_f32_vector(out, out_fp32); - reduce_util::to_f32_vector(out_ref, out_ref_fp32); - single_result = ck::utils::check_err( - out_fp32.mData, out_ref_fp32.mData, "Error: incorrect data result!"); - } - else - { - single_result = - ck::utils::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); - }; - - if(!single_result) - { - std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << " => " - << reduce2_ptr->GetTypeString() << std::endl; - result = false; - } - }; - }; - - return (result); -}; - -} // anonymous namespace - static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, {"reduceDimensions", required_argument, nullptr, 'R'}, {"scales", required_argument, nullptr, 'S'}, @@ -387,48 +13,6 @@ static struct option long_options[] = {{"inLengths", required_argument, nullptr, class SimpleAppArgs { - template - static T getSingleValueFromString(const std::string& valueStr) - { - std::istringstream iss(valueStr); - - T ret; - - iss >> ret; - - return (ret); - }; - - template - static std::vector getTypeValuesFromString(const char* cstr_values) - { - std::string valuesStr(cstr_values); - - std::vector values; - std::size_t pos = 0; - std::size_t new_pos; - - new_pos = valuesStr.find(',', pos); - while(new_pos != std::string::npos) - { - const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); - - T val = getSingleValueFromString(sliceStr); - - values.push_back(val); - - pos = new_pos + 1; - new_pos = valuesStr.find(',', pos); - }; - - std::string sliceStr = valuesStr.substr(pos); - T val = getSingleValueFromString(sliceStr); - - values.push_back(val); - - return (values); - }; - private: int option_index = 0; @@ -460,7 +44,9 @@ class SimpleAppArgs int processArgs(int argc, char* argv[]) { - unsigned int ch; + using ck::host_common::getTypeValuesFromString; + + int ch; while(1) { @@ -514,7 +100,7 @@ class SimpleAppArgs (reduceDims.size() != 1 && reduceDims.size() != 3 && reduceDims.size() != 4)) return (-1); - if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5) + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) return (-1); return (0); @@ -525,87 +111,92 @@ bool test_reduce_no_index(int data_type, int init_method, std::vector reduceDims, std::vector inLengths, + ReduceTensorOp reduceOpId, + bool propagateNan, float alpha, float beta) { + using ck::profiler::profile_reduce_impl; + bool result = true; if(data_type == 0) { - switch(reduceDims.size()) - { - case 1: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 3: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 4: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - }; + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); } else if(data_type == 1) { - switch(reduceDims.size()) - { - case 1: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 3: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 4: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - }; + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); } else if(data_type == 3) { - switch(reduceDims.size()) - { - case 1: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 3: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 4: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - }; + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); } else if(data_type == 5) { - switch(reduceDims.size()) - { - case 1: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 3: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 4: - result = test_reduce_no_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - }; + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); + } + else if(data_type == 6) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + false, + alpha, + beta); } return (result); }; +constexpr ReduceTensorOp reduceOpId = ReduceTensorOp::AVG; +constexpr bool propagateNan = false; + int main(int argc, char* argv[]) { SimpleAppArgs args; @@ -621,8 +212,14 @@ int main(int argc, char* argv[]) {0, 1, 2, 3}, {0, 1, 2}, {1, 2, 3}, {0, 1, 3}, {0, 2, 3}, {0}, {1}, {2}, {3}}; for(auto& reduceDims : v_reduceDims) - result = result && test_reduce_no_index( - data_type, init_method, reduceDims, inLengths, 1.0f, 0.0f); + result = result && test_reduce_no_index(data_type, + init_method, + reduceDims, + inLengths, + reduceOpId, + propagateNan, + 1.0f, + 0.0f); } else { @@ -636,6 +233,8 @@ int main(int argc, char* argv[]) args.init_method, args.reduceDims, args.inLengths, + reduceOpId, + propagateNan, args.scales[0], args.scales[1]); } diff --git a/test/reduce/reduce_util.hpp b/test/reduce/reduce_util.hpp deleted file mode 100644 index e9a7b4896e..0000000000 --- a/test/reduce/reduce_util.hpp +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef REDUCE_UTILS_HPP -#define REDUCE_UTILS_HPP - -#include "data_type.hpp" - -namespace ck { -namespace reduce_util { - -template -void to_f32_vector(const Tensor& src, Tensor& dst) -{ - for(int i = 0; i < src.mData.size(); ++i) - dst.mData[i] = type_convert(src.mData[i]); -} - -} // namespace reduce_util - -} // namespace ck -#endif diff --git a/test/reduce/reduce_with_index.cpp b/test/reduce/reduce_with_index.cpp index 667b84a8dc..c1918bf388 100644 --- a/test/reduce/reduce_with_index.cpp +++ b/test/reduce/reduce_with_index.cpp @@ -1,387 +1,10 @@ #include "getopt.h" -#include "device_reduce_instance.hpp" -#include "reduction_enums.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_reduction.hpp" -#include "check_err.hpp" -#include "reduce_util.hpp" + +#include "host_common_util.hpp" +#include "profile_reduce_impl.hpp" using namespace ck; -namespace { - -template -static inline std::vector get_invariant_dims(const std::vector& reduceDims) -{ - assert(NumReduceDim == reduceDims.size()); - - int reduceFlag = 0; - - // flag the bits for the reduceDims - for(int i = 0; i < NumReduceDim; i++) - { - reduceFlag |= 1 << reduceDims[i]; - }; - - std::vector invariantDims; - - // collect invariant dimensions - for(int i = 0; i < Rank; i++) - if((reduceFlag & (1 << i)) == 0) - { - invariantDims.push_back(i); - }; - - return invariantDims; -}; - -constexpr int Rank = 4; - -constexpr ReduceTensorOp ReduceOpId = ReduceTensorOp::AMAX; -constexpr NanPropagation NanOpt = NanPropagation::PROPAGATE_NAN; -constexpr bool PropagateNan = false; -constexpr ReduceTensorIndices IndicesOpt = ReduceTensorIndices::FLATTENED_INDICES; -constexpr bool NeedIndices = true; - -template -bool test_reduce_with_index_impl(int init_method, - const std::vector& inLengths, - const std::vector& reduceDims, - float alpha, - float beta) -{ - using namespace ck::tensor_operation::device; - using namespace ck::tensor_operation::device::device_reduce_instance; - using namespace ck::host_reduce; - - Tensor in(inLengths); - - std::vector outLengths; - - const auto invariantDims = get_invariant_dims(reduceDims); - - if(reduceDims.size() == Rank) - outLengths.push_back(1); - else - for(auto dim : invariantDims) - outLengths.push_back(inLengths[dim]); - - Tensor out_ref(outLengths); - Tensor out(outLengths); - Tensor out_indices_ref(outLengths); - Tensor out_indices(outLengths); - - // only used when the OutDataType is bhalf_t - Tensor out_ref_fp32(outLengths); - Tensor out_fp32(outLengths); - - auto inStrides = in.mDesc.GetStrides(); - auto outStrides = out.mDesc.GetStrides(); - - size_t invariant_total_length = out.mDesc.GetElementSize(); - size_t reduce_total_length = in.mDesc.GetElementSize() / invariant_total_length; - - std::size_t num_thread = 1; - - switch(init_method) - { - case 0: break; - case 1: - in.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); - if(beta != 0.0f) - out_ref.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); - break; - case 2: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - if(beta != 0.0f) - out_ref.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - default: - in.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); - if(beta != 0.0f) - out_ref.GenerateTensorValue(GeneratorTensor_3{-5.0, 5.0}, num_thread); - } - - if(beta != 0.0f) - for(size_t i = 0; i < out_ref.mDesc.GetElementSpace(); i++) - out.mData[i] = out_ref.mData[i]; - - // these buffers are usually provided by the user application - DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); - DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); - - in_dev.ToDevice(in.mData.data()); - - if(beta != 0.0f) - out_dev.ToDevice(out.mData.data()); - - size_t indicesSizeInBytes = NeedIndices ? out.mDesc.GetElementSize() * sizeof(int) : 0; - - DeviceMem out_indices_dev(indicesSizeInBytes); - - using InElementwiseOperation_0 = - typename reduce_unary_operator::InElementwiseOperation; - using AccElementwiseOperation_0 = - typename reduce_unary_operator:: - AccElementwiseOperation; - using InElementwiseOperation_1 = - typename reduce_unary_operator:: - InElementwiseOperation; - using AccElementwiseOperation_1 = - typename reduce_unary_operator:: - AccElementwiseOperation; - using InElementwiseOperation_2 = - typename reduce_unary_operator:: - InElementwiseOperation; - using AccElementwiseOperation_2 = - typename reduce_unary_operator:: - AccElementwiseOperation; - - using DeviceReduceInstPtr0 = - DeviceReducePtr; - using DeviceReduceInstPtr1 = - DeviceReducePtr; - using DeviceReduceInstPtr2 = - DeviceReducePtr; - - std::vector reduce0_ptrs; - std::vector reduce1_ptrs; - std::vector reduce2_ptrs; - - add_device_reduce_instance_threadwise(reduce0_ptrs); - - add_device_reduce_instance_blockwise(reduce0_ptrs); - - add_device_reduce_instance_multiblock_partial_reduce(reduce1_ptrs); - - add_device_reduce_instance_blockwise_second_call(reduce2_ptrs); - - if(reduce0_ptrs.empty() && reduce1_ptrs.empty()) - { - throw std::runtime_error("Wrong! No device REDUCE instance found"); - }; - - bool result = true; - - ReductionHost - hostReduce(in.mDesc, out_ref.mDesc, invariantDims, reduceDims); - - hostReduce.Run( - alpha, in.mData.data(), beta, out_ref.mData.data(), out_indices_ref.mData.data()); - - const auto i_inLengths = to_int_vector(inLengths); - const auto i_inStrides = to_int_vector(inStrides); - const auto i_outLengths = to_int_vector(outLengths); - const auto i_outStrides = to_int_vector(outStrides); - - for(auto& reduce_ptr : reduce0_ptrs) - { - auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths, reduceDims); - - DeviceMem ws_dev(wsSizeInBytes); - - InElementwiseOperation_0 in_elementwise_op_0(static_cast(reduce_total_length)); - AccElementwiseOperation_0 acc_elementwise_op_0(static_cast(reduce_total_length)); - - auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, - i_inStrides, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - in_dev.GetDeviceBuffer(), - out_dev.GetDeviceBuffer(), - out_indices_dev.GetDeviceBuffer(), - ws_dev.GetDeviceBuffer(), - in_elementwise_op_0, - acc_elementwise_op_0); - - if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) - continue; - - auto invoker_ptr = reduce_ptr->MakeInvokerPointer(); - - (void)invoker_ptr->Run(argument_ptr.get()); - - out_dev.FromDevice(out.mData.data()); - - bool single_result = true; - - if constexpr(std::is_same::value || - std::is_same::value) - { - reduce_util::to_f32_vector(out, out_fp32); - reduce_util::to_f32_vector(out_ref, out_ref_fp32); - single_result = ck::utils::check_err( - out_fp32.mData, out_ref_fp32.mData, "Error: incorrect data result!"); - } - else - { - single_result = - ck::utils::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); - }; - - if(NeedIndices) - { - out_indices_dev.FromDevice(out_indices.mData.data()); - single_result = single_result && ck::utils::check_err(out_indices_ref.mData, - out_indices.mData, - "Error: incorrect index result!"); - }; - - if(!single_result) - { - std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << std::endl; - result = false; - } - }; - - for(auto& reduce_ptr : reduce1_ptrs) - { - auto wsSizeInBytes = reduce_ptr->GetWorkspaceSizeInBytes(i_inLengths, reduceDims); - - DeviceMem ws_dev(wsSizeInBytes); - - InElementwiseOperation_1 in_elementwise_op_1(static_cast(reduce_total_length)); - AccElementwiseOperation_1 acc_elementwise_op_1(static_cast(reduce_total_length)); - - auto argument_ptr = reduce_ptr->MakeArgumentPointer(i_inLengths, - i_inStrides, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - in_dev.GetDeviceBuffer(), - out_dev.GetDeviceBuffer(), - out_indices_dev.GetDeviceBuffer(), - ws_dev.GetDeviceBuffer(), - in_elementwise_op_1, - acc_elementwise_op_1); - - if(!reduce_ptr->IsSupportedArgument(argument_ptr.get())) - continue; - - std::string reduce_name = reduce_ptr->GetTypeString(); - - auto invoker_ptr = reduce_ptr->MakeInvokerPointer(); - - (void)invoker_ptr->Run(argument_ptr.get()); - - std::vector inLengths2 = reduce_ptr->GetWorkspace2dLengths(argument_ptr.get()); - std::vector inStrides2{inLengths2[1], 1}; - - for(auto& reduce2_ptr : reduce2_ptrs) - { - InElementwiseOperation_2 in_elementwise_op_2(static_cast(reduce_total_length)); - AccElementwiseOperation_2 acc_elementwise_op_2( - static_cast(reduce_total_length)); - - auto argument2_ptr = reduce2_ptr->MakeArgumentPointer(inLengths2, - inStrides2, - i_outLengths, - i_outStrides, - reduceDims, - alpha, - beta, - ws_dev.GetDeviceBuffer(), - out_dev.GetDeviceBuffer(), - out_indices_dev.GetDeviceBuffer(), - ws_dev.GetDeviceBuffer(), - in_elementwise_op_2, - acc_elementwise_op_2); - - if(!reduce2_ptr->IsSupportedArgument(argument2_ptr.get())) - continue; - - std::string reduce2_name = reduce2_ptr->GetTypeString(); - - auto invoker2_ptr = reduce2_ptr->MakeInvokerPointer(); - - (void)invoker2_ptr->Run(argument2_ptr.get()); - - out_dev.FromDevice(out.mData.data()); - - bool single_result = true; - - if constexpr(std::is_same::value || - std::is_same::value) - { - reduce_util::to_f32_vector(out, out_fp32); - reduce_util::to_f32_vector(out_ref, out_ref_fp32); - single_result = ck::utils::check_err( - out_fp32.mData, out_ref_fp32.mData, "Error: incorrect data result!"); - } - else - { - single_result = - ck::utils::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); - }; - - if(NeedIndices) - { - out_indices_dev.FromDevice(out_indices.mData.data()); - single_result = - single_result && ck::utils::check_err(out_indices_ref.mData, - out_indices.mData, - "Error: incorrect index result!"); - }; - - if(!single_result) - { - std::cout << "Fail Info: " << reduce_ptr->GetTypeString() << " => " - << reduce2_ptr->GetTypeString() << std::endl; - result = false; - } - }; - }; - - return (result); -}; - -} // anonymous namespace - static struct option long_options[] = {{"inLengths", required_argument, nullptr, 'D'}, {"reduceDimensions", required_argument, nullptr, 'R'}, {"scales", required_argument, nullptr, 'S'}, @@ -390,48 +13,6 @@ static struct option long_options[] = {{"inLengths", required_argument, nullptr, class SimpleAppArgs { - template - static T getSingleValueFromString(const std::string& valueStr) - { - std::istringstream iss(valueStr); - - T ret; - - iss >> ret; - - return (ret); - }; - - template - static std::vector getTypeValuesFromString(const char* cstr_values) - { - std::string valuesStr(cstr_values); - - std::vector values; - std::size_t pos = 0; - std::size_t new_pos; - - new_pos = valuesStr.find(',', pos); - while(new_pos != std::string::npos) - { - const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); - - T val = getSingleValueFromString(sliceStr); - - values.push_back(val); - - pos = new_pos + 1; - new_pos = valuesStr.find(',', pos); - }; - - std::string sliceStr = valuesStr.substr(pos); - T val = getSingleValueFromString(sliceStr); - - values.push_back(val); - - return (values); - }; - private: int option_index = 0; @@ -463,7 +44,9 @@ class SimpleAppArgs int processArgs(int argc, char* argv[]) { - unsigned int ch; + using ck::host_common::getTypeValuesFromString; + + int ch; while(1) { @@ -517,7 +100,7 @@ class SimpleAppArgs (reduceDims.size() != 1 && reduceDims.size() != 3 && reduceDims.size() != 4)) return (-1); - if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5) + if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6) return (-1); return (0); @@ -528,87 +111,92 @@ bool test_reduce_with_index(int data_type, int init_method, std::vector reduceDims, std::vector inLengths, + ReduceTensorOp reduceOpId, + bool propagateNan, float alpha, float beta) { + using ck::profiler::profile_reduce_impl; + bool result = true; if(data_type == 0) { - switch(reduceDims.size()) - { - case 1: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 3: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 4: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - }; + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); } else if(data_type == 1) { - switch(reduceDims.size()) - { - case 1: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 3: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 4: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - }; + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); } else if(data_type == 3) { - switch(reduceDims.size()) - { - case 1: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 3: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 4: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - }; + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); } else if(data_type == 5) { - switch(reduceDims.size()) - { - case 1: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 3: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - case 4: - result = test_reduce_with_index_impl( - init_method, inLengths, reduceDims, alpha, beta); - break; - }; + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); + } + else if(data_type == 6) + { + result = profile_reduce_impl(true, + init_method, + false, + false, + inLengths, + reduceDims, + reduceOpId, + propagateNan, + true, + alpha, + beta); } return (result); }; +constexpr ReduceTensorOp reduceOpId = ReduceTensorOp::AMAX; +constexpr bool propagateNan = false; + int main(int argc, char* argv[]) { SimpleAppArgs args; @@ -624,8 +212,14 @@ int main(int argc, char* argv[]) {0, 1, 2, 3}, {0, 1, 2}, {1, 2, 3}, {0, 1, 3}, {0, 2, 3}, {0}, {1}, {2}, {3}}; for(auto& reduceDims : v_reduceDims) - result = result && test_reduce_with_index( - data_type, init_method, reduceDims, inLengths, 1.0f, 0.0f); + result = result && test_reduce_with_index(data_type, + init_method, + reduceDims, + inLengths, + reduceOpId, + propagateNan, + 1.0f, + 0.0f); } else { @@ -639,6 +233,8 @@ int main(int argc, char* argv[]) args.init_method, args.reduceDims, args.inLengths, + reduceOpId, + propagateNan, args.scales[0], args.scales[1]); } diff --git a/test/reference_conv_fwd/CMakeLists.txt b/test/reference_conv_fwd/CMakeLists.txt index e5a7b31aff..04b720b169 100644 --- a/test/reference_conv_fwd/CMakeLists.txt +++ b/test/reference_conv_fwd/CMakeLists.txt @@ -1,2 +1,2 @@ add_gtest_executable(test_reference_conv_fwd reference_conv_fwd.cpp) -target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor conv_fwd_util) +target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor conv_util) diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp index f660559e62..69b223989f 100644 --- a/test/reference_conv_fwd/reference_conv_fwd.cpp +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -8,7 +8,7 @@ #include "check_err.hpp" #include "config.hpp" -#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "element_wise_operation.hpp" #include "fill.hpp" #include "host_tensor.hpp" @@ -34,21 +34,21 @@ run_reference_convolution_forward(const ck::utils::conv::ConvParams& params, const FillInputOp& fill_input_op = FillInputOp{}, const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) { - std::vector input_dims{static_cast(params.N), - static_cast(params.C)}; + std::vector input_dims{static_cast(params.N_), + static_cast(params.C_)}; input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths), - std::end(params.input_spatial_lengths)); + std::begin(params.input_spatial_lengths_), + std::end(params.input_spatial_lengths_)); - std::vector filter_dims{static_cast(params.K), - static_cast(params.C)}; + std::vector filter_dims{static_cast(params.K_), + static_cast(params.C_)}; filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths), - std::end(params.filter_spatial_lengths)); + std::begin(params.filter_spatial_lengths_), + std::end(params.filter_spatial_lengths_)); const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N), - static_cast(params.K)}; + std::vector output_dims{static_cast(params.N_), + static_cast(params.K_)}; output_dims.insert(std::end(output_dims), std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); @@ -74,10 +74,10 @@ run_reference_convolution_forward(const ck::utils::conv::ConvParams& params, auto ref_argument = ref_conv.MakeArgument(input, weights, host_output, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, InElementOp{}, WeiElementOp{}, OutElementOp{}); @@ -91,15 +91,15 @@ run_reference_convolution_forward(const ck::utils::conv::ConvParams& params, TEST(ReferenceConvolutionFWD, Conv2DNHWC) { ck::utils::conv::ConvParams params; - params.N = 1; - params.K = 1; - params.C = 2; - params.filter_spatial_lengths = std::vector{3, 3}; - params.input_spatial_lengths = std::vector{6, 6}; - params.conv_filter_strides = std::vector{1, 1}; - params.conv_filter_dilations = std::vector{1, 1}; - params.input_left_pads = std::vector{0, 0}; - params.input_right_pads = std::vector{0, 0}; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3}; + params.input_spatial_lengths_ = std::vector{6, 6}; + params.conv_filter_strides_ = std::vector{1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1}; + params.input_left_pads_ = std::vector{0, 0}; + params.input_right_pads_ = std::vector{0, 0}; auto out_tensor = run_reference_convolution_forward<2>(params); std::vector ref_dims{1, 1, 4, 4}; @@ -127,15 +127,15 @@ TEST(ReferenceConvolutionFWD, Conv2DNHWC) TEST(ReferenceConvolutionFWD, Conv2DNHWCStridesDilationsPadding) { ck::utils::conv::ConvParams params; - params.N = 1; - params.K = 2; - params.C = 2; - params.filter_spatial_lengths = std::vector{3, 3}; - params.input_spatial_lengths = std::vector{12, 12}; - params.conv_filter_strides = std::vector{2, 2}; - params.conv_filter_dilations = std::vector{2, 2}; - params.input_left_pads = std::vector{1, 1}; - params.input_right_pads = std::vector{1, 1}; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3}; + params.input_spatial_lengths_ = std::vector{12, 12}; + params.conv_filter_strides_ = std::vector{2, 2}; + params.conv_filter_dilations_ = std::vector{2, 2}; + params.input_left_pads_ = std::vector{1, 1}; + params.input_right_pads_ = std::vector{1, 1}; auto out_tensor = run_reference_convolution_forward<2>(params); std::vector ref_dims = std::vector{1, 2, 5, 5}; @@ -153,16 +153,16 @@ TEST(ReferenceConvolutionFWD, Conv2DNHWCStridesDilationsPadding) TEST(ReferenceConvolutionFWD, Conv1DNWC) { ck::utils::conv::ConvParams params; - params.num_dim_spatial = 1; - params.N = 1; - params.K = 1; - params.C = 2; - params.filter_spatial_lengths = std::vector{3}; - params.input_spatial_lengths = std::vector{6}; - params.conv_filter_strides = std::vector{1}; - params.conv_filter_dilations = std::vector{1}; - params.input_left_pads = std::vector{0}; - params.input_right_pads = std::vector{0}; + params.num_dim_spatial_ = 1; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{6}; + params.conv_filter_strides_ = std::vector{1}; + params.conv_filter_dilations_ = std::vector{1}; + params.input_left_pads_ = std::vector{0}; + params.input_right_pads_ = std::vector{0}; auto out_tensor = run_reference_convolution_forward<1, @@ -182,16 +182,16 @@ TEST(ReferenceConvolutionFWD, Conv1DNWC) TEST(ReferenceConvolutionFWD, Conv1DNWCStridesDilationsPadding) { ck::utils::conv::ConvParams params; - params.num_dim_spatial = 1; - params.N = 1; - params.K = 2; - params.C = 2; - params.filter_spatial_lengths = std::vector{3}; - params.input_spatial_lengths = std::vector{12}; - params.conv_filter_strides = std::vector{2}; - params.conv_filter_dilations = std::vector{2}; - params.input_left_pads = std::vector{1}; - params.input_right_pads = std::vector{1}; + params.num_dim_spatial_ = 1; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{12}; + params.conv_filter_strides_ = std::vector{2}; + params.conv_filter_dilations_ = std::vector{2}; + params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; auto out_tensor = run_reference_convolution_forward<1, @@ -211,16 +211,16 @@ TEST(ReferenceConvolutionFWD, Conv1DNWCStridesDilationsPadding) TEST(ReferenceConvolutionFWD, Conv1DNWCSameOutputSize) { ck::utils::conv::ConvParams params; - params.num_dim_spatial = 1; - params.N = 2; - params.K = 16; - params.C = 4; - params.filter_spatial_lengths = std::vector{3}; - params.input_spatial_lengths = std::vector{16}; - params.conv_filter_strides = std::vector{1}; - params.conv_filter_dilations = std::vector{1}; - params.input_left_pads = std::vector{1}; - params.input_right_pads = std::vector{1}; + params.num_dim_spatial_ = 1; + params.N_ = 2; + params.K_ = 16; + params.C_ = 4; + params.filter_spatial_lengths_ = std::vector{3}; + params.input_spatial_lengths_ = std::vector{16}; + params.conv_filter_strides_ = std::vector{1}; + params.conv_filter_dilations_ = std::vector{1}; + params.input_left_pads_ = std::vector{1}; + params.input_right_pads_ = std::vector{1}; auto out_tensor2 = run_reference_convolution_forward<1, float, @@ -305,16 +305,16 @@ TEST(ReferenceConvolutionFWD, Conv1DNWCSameOutputSize) TEST(ReferenceConvolutionFWD, Conv3DNCDHW) { ck::utils::conv::ConvParams params; - params.num_dim_spatial = 3; - params.N = 1; - params.K = 1; - params.C = 2; - params.filter_spatial_lengths = std::vector{3, 3, 3}; - params.input_spatial_lengths = std::vector{6, 6, 6}; - params.conv_filter_strides = std::vector{1, 1, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{0, 0, 0}; - params.input_right_pads = std::vector{0, 0, 0}; + params.num_dim_spatial_ = 3; + params.N_ = 1; + params.K_ = 1; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{6, 6, 6}; + params.conv_filter_strides_ = std::vector{1, 1, 1}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{0, 0, 0}; + params.input_right_pads_ = std::vector{0, 0, 0}; auto out_tensor = run_reference_convolution_forward<3, float, @@ -344,16 +344,16 @@ TEST(ReferenceConvolutionFWD, Conv3DNCDHW) TEST(ReferenceConvolutionFWD, Conv3DNCDHWStridesDilations) { ck::utils::conv::ConvParams params; - params.num_dim_spatial = 3; - params.N = 1; - params.K = 2; - params.C = 2; - params.filter_spatial_lengths = std::vector{3, 3, 3}; - params.input_spatial_lengths = std::vector{12, 12, 12}; - params.conv_filter_strides = std::vector{3, 3, 3}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{0, 0, 0}; - params.input_right_pads = std::vector{0, 0, 0}; + params.num_dim_spatial_ = 3; + params.N_ = 1; + params.K_ = 2; + params.C_ = 2; + params.filter_spatial_lengths_ = std::vector{3, 3, 3}; + params.input_spatial_lengths_ = std::vector{12, 12, 12}; + params.conv_filter_strides_ = std::vector{3, 3, 3}; + params.conv_filter_dilations_ = std::vector{1, 1, 1}; + params.input_left_pads_ = std::vector{0, 0, 0}; + params.input_right_pads_ = std::vector{0, 0, 0}; auto out_tensor = run_reference_convolution_forward<3, float, diff --git a/test/softmax/CMakeLists.txt b/test/softmax/CMakeLists.txt new file mode 100644 index 0000000000..50ec04f9e4 --- /dev/null +++ b/test/softmax/CMakeLists.txt @@ -0,0 +1,8 @@ +add_custom_target(test_softmax) + +add_gtest_executable(test_softmax_fp32 test_softmax_fp32.cpp) +add_gtest_executable(test_softmax_fp16 test_softmax_fp16.cpp) +target_link_libraries(test_softmax_fp32 PRIVATE host_tensor) +target_link_libraries(test_softmax_fp16 PRIVATE host_tensor) +add_dependencies(test_softmax test_softmax_fp32) +add_dependencies(test_softmax test_softmax_fp16) \ No newline at end of file diff --git a/test/softmax/test_softmax_fp16.cpp b/test/softmax/test_softmax_fp16.cpp new file mode 100644 index 0000000000..9ea204a5ee --- /dev/null +++ b/test/softmax/test_softmax_fp16.cpp @@ -0,0 +1,26 @@ +#include "gtest/gtest.h" +#include "test_softmax_util.hpp" + +template +using I = ck::Number; + +template +class TestSoftmaxFP16 : public ck::TestSoftmax +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< +// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>> + >; +// clang-format on +TYPED_TEST_SUITE(TestSoftmaxFP16, KernelTypes); +TYPED_TEST(TestSoftmaxFP16, Test_FP16) { this->Run(); } diff --git a/test/softmax/test_softmax_fp32.cpp b/test/softmax/test_softmax_fp32.cpp new file mode 100644 index 0000000000..a7f6cf6b5d --- /dev/null +++ b/test/softmax/test_softmax_fp32.cpp @@ -0,0 +1,26 @@ +#include "gtest/gtest.h" +#include "test_softmax_util.hpp" + +template +using I = ck::Number; + +template +class TestSoftmaxFP32 : public ck::TestSoftmax +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< +// InDataType, AccDataType, OutDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize> + std::tuple, I<1>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<1>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<8>, I<32>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<4>, I<64>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<2>, I<128>, I<1>, I<4>, I<1>, I<4>, I<4>>, + std::tuple, I<2>, I<256>, I<1>, I<256>, I<1>, I<4>, I<1>, I<4>, I<4>> + >; +// clang-format on +TYPED_TEST_SUITE(TestSoftmaxFP32, KernelTypes); +TYPED_TEST(TestSoftmaxFP32, Test_FP32) { this->Run(); } diff --git a/test/softmax/test_softmax_util.hpp b/test/softmax/test_softmax_util.hpp new file mode 100644 index 0000000000..39182c3c11 --- /dev/null +++ b/test/softmax/test_softmax_util.hpp @@ -0,0 +1,113 @@ +#include +#include +#include "gtest/gtest.h" + +#include "config.hpp" +#include "host_tensor.hpp" +#include "check_err.hpp" +#include "number.hpp" +#include "reference_softmax.hpp" +#include "device_softmax.hpp" + +namespace ck { + +template +class TestSoftmax : public ::testing::Test +{ + protected: + using InDataType = std::tuple_element_t<0, Tuple>; + using AccDataType = std::tuple_element_t<1, Tuple>; + using OutDataType = std::tuple_element_t<2, Tuple>; + static constexpr index_t Rank = std::tuple_element_t<3, Tuple>{}.value; + static constexpr index_t NumReduceDim = std::tuple_element_t<4, Tuple>{}.value; + static constexpr index_t BlockSize = std::tuple_element_t<5, Tuple>{}.value; + static constexpr index_t MThreadClusterSize = std::tuple_element_t<6, Tuple>{}.value; + static constexpr index_t KThreadClusterSize = std::tuple_element_t<7, Tuple>{}.value; + static constexpr index_t MThreadSliceSize = std::tuple_element_t<8, Tuple>{}.value; + static constexpr index_t KThreadSliceSize = std::tuple_element_t<9, Tuple>{}.value; + static constexpr index_t InSrcVectorDim = std::tuple_element_t<10, Tuple>{}.value; + static constexpr index_t InSrcVectorSize = std::tuple_element_t<11, Tuple>{}.value; + static constexpr index_t OutDstVectorSize = std::tuple_element_t<12, Tuple>{}.value; + + using ReferenceInstance = + tensor_operation::host::ReferenceSoftmax; + + using DeviceInstance = tensor_operation::device::DeviceSoftmax; + + TestSoftmax() : ref_instance_invoker_(ReferenceInstance{}.MakeInvoker()) {} + + void RunSingle(std::vector in_length, AccDataType alpha, AccDataType beta) + { + std::vector reduce_dims(NumReduceDim); + std::iota(reduce_dims.begin(), reduce_dims.end(), Rank - NumReduceDim); + + Tensor in(in_length); + Tensor out(in_length); + + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + Tensor out_ref(out); + + DeviceMem in_dev(sizeof(InDataType) * in.mDesc.GetElementSpace()); + DeviceMem out_dev(sizeof(OutDataType) * out.mDesc.GetElementSpace()); + in_dev.ToDevice(in.mData.data()); + out_dev.ToDevice(out.mData.data()); + + std::vector i_in_lengths(in.mDesc.GetLengths().begin(), + in.mDesc.GetLengths().end()); + std::vector i_in_strides(in.mDesc.GetStrides().begin(), + in.mDesc.GetStrides().end()); + + auto device_instance = DeviceInstance{}; + auto argument_ptr = device_instance.MakeArgumentPointer(i_in_lengths, + i_in_strides, + reduce_dims, + alpha, + beta, + in_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer()); + + if(!device_instance.IsSupportedArgument(argument_ptr.get())) + { + FAIL() << "Unsupported argument"; + } + + auto invoker_ptr = device_instance.MakeInvokerPointer(); + invoker_ptr->Run(argument_ptr.get()); + + ref_instance_invoker_.Run({in, out_ref, alpha, beta, Rank, reduce_dims}); + + out_dev.FromDevice(out.mData.data()); + EXPECT_TRUE(ck::utils::check_err(out.mData, out_ref.mData)); + } + + void Run() + { + for(auto in_length : this->in_lengths_) + { + for(auto scale : this->scales_) + { + this->RunSingle(in_length, std::get<0>(scale), std::get<1>(scale)); + } + } + } + + std::vector> in_lengths_ = {{1, 8, 128}, {2, 128, 1024}, {3, 9, 1032}}; + std::vector> scales_ = {{1, 0}, {2, 2}, {0, 1}}; + + typename ReferenceInstance::Invoker ref_instance_invoker_; +}; +} // namespace ck