diff --git a/.gitignore b/.gitignore index 5667695bb5..362fb9e2ef 100644 --- a/.gitignore +++ b/.gitignore @@ -48,6 +48,11 @@ build* .gdb_history install.dir* -# directories containing generated documentation -docs/source/_build/ -docs/docBin/ +# documentation artifacts +build/ +_build/ +_images/ +_static/ +_templates/ +_toc.yml +docBin/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000000..b739536837 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,18 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.8" + +sphinx: + configuration: docs/conf.py + +formats: [htmlzip] + +python: + install: + - requirements: docs/.sphinx/requirements.txt diff --git a/CHANGELOG.md b/CHANGELOG.md index 79c45a0db3..0188350046 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ Full documentation for Composable Kernel is not yet available. -## CK 0.1.1 for ROCm 5.5.0 +## CK 0.2.0 for ROCm 5.5.0 ### Fixed - Fixed a bug in 6-dimensional kernels (#555). @@ -12,6 +12,7 @@ Full documentation for Composable Kernel is not yet available. - Improve proformance of normalization kernel ### Added +- Added support on NAVI3x. - Added user tutorial (#563). - Added more instances for irregular GEMM sizes (#560). - Added inter-wave consumer-producer programming model for GEMM kernels (#310). diff --git a/CMakeLists.txt b/CMakeLists.txt index f861e30203..c9fb6b4552 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,7 @@ include(TargetFlags) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF) +option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF) if(USE_BITINT_EXTENSION_INT4) add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) @@ -29,6 +30,12 @@ if(USE_BITINT_EXTENSION_INT4) message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}") endif() +if(USE_OPT_NAVI3X) + add_compile_options(-mcumode) + add_compile_options(-mno-wavefrontsize64) + message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}") +endif() + ## Threads set(THREADS_PREFER_PTHREAD_FLAG ON) find_package(Threads REQUIRED) diff --git a/Dockerfile b/Dockerfile index dd2a97c7bd..b03cb836ad 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,8 @@ ARG compiler_commit="" RUN set -xe ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +RUN useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins +RUN useradd -rm -d /home/manitera -s /bin/bash -u 1002 manitera # Add rocm repository RUN apt-get update RUN apt-get install -y wget gnupg @@ -37,6 +39,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- python-dev \ python3-dev \ python3-pip \ + sshpass \ software-properties-common \ rocm-dev \ rocm-device-libs \ diff --git a/Jenkinsfile b/Jenkinsfile index 1dbb0a4d9e..edc8d47e4b 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -14,7 +14,6 @@ def show_node_info() { def runShell(String command){ def responseCode = sh returnStatus: true, script: "${command} > tmp.txt" def output = readFile(file: "tmp.txt") - echo "tmp.txt contents: $output" return (output != "") } @@ -172,7 +171,7 @@ def cmake_build(Map conf=[:]){ if(conf.get("build_install","") == "true") { config_targets = 'install ' + config_targets - setup_args = ' -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install' + setup_args + setup_args = ' -DBUILD_DEV=On -DCMAKE_INSTALL_PREFIX=../install' + setup_args } else{ setup_args = ' -DBUILD_DEV=On' + setup_args } @@ -427,6 +426,7 @@ def Build_CK(Map conf=[:]){ def variant = env.STAGE_NAME def retimage + def navi_node = 0 gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCmSoftwarePlatform', repo: 'composable_kernel-internal') { try { @@ -440,6 +440,9 @@ def Build_CK(Map conf=[:]){ else{ echo "GPU is OK" } + if ( runShell('grep -n "gfx1030" clinfo.log') ){ + navi_node = 1 + } } } } @@ -458,6 +461,9 @@ def Build_CK(Map conf=[:]){ else{ echo "GPU is OK" } + if ( runShell('grep -n "gfx1030" clinfo.log') ){ + navi_node = 1 + } } } } @@ -466,16 +472,20 @@ def Build_CK(Map conf=[:]){ { cmake_build(conf) dir("build"){ - //run tests and examples - sh 'make -j check' - //we only need the ckProfiler to run the performance tests, so we pack and stash it - sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' - stash "ckProfiler.tar.gz" + if (navi_node == 0 ){ + //run tests and examples on all nodes except Navi + sh 'make -j check' + //we only need the ckProfiler to run the performance tests, so we pack and stash it + sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' + stash "ckProfiler.tar.gz" + } if (params.RUN_FULL_QA){ // build deb packages sh 'make -j package' archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' archiveArtifacts artifacts: 'composablekernel-tests_*.deb' + sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' + stash "ckprofiler_0.2.0_amd64.deb" } } } @@ -543,6 +553,8 @@ def process_results(Map conf=[:]){ unstash "perf_splitK_gemm.log" unstash "perf_onnx_gemm.log" sh "./process_qa_data.sh" + unstash "ckprofiler_0.2.0_amd64.deb" + sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" } else{ // unstash perf files to master @@ -645,12 +657,28 @@ pipeline { { parallel { - stage("Build CK and run Tests") + stage("Build CK and run Tests on MI100/MI200") { agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" """ : """ -DBUILD_DEV=Off -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3 " """ }" - execute_args = "${params.COMPILER_VERSION == "ck-9110" ? """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ : """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908,gfx90a;gfx1030" -DCMAKE_CXX_FLAGS="-O3" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ }" + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908,gfx90a" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + } + } + stage("Build CK and run Tests on Navi") + { + when { + beforeAgent true + expression { !params.RUN_FULL_QA.toBoolean() } + } + agent{ label rocmnode("navi21") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030;gfx1100;gfx1101;gfx1102" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ + } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -663,7 +691,7 @@ pipeline { { parallel { - stage("Run ckProfiler: gfx908 or gfx90a") + stage("Run ckProfiler: gfx90*") { when { beforeAgent true @@ -672,7 +700,7 @@ pipeline { options { retry(2) } agent{ label rocmnode("gfx908 || gfx90a")} environment{ - setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS=" -O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -DBUILD_DEV=On """ : """ -DGPU_TARGETS="gfx908;gfx90a;gfx1030" -DCMAKE_CXX_FLAGS=" -O3 " -DBUILD_DEV=On """}" + setup_args = """ -DGPU_TARGETS="gfx908;gfx90a" -DBUILD_DEV=On """ } steps{ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') @@ -687,7 +715,7 @@ pipeline { options { retry(2) } agent{ label rocmnode("gfx90a")} environment{ - setup_args = "${params.COMPILER_VERSION == "ck-9110" ? """ -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 -Xclang -mlink-builtin-bitcode -Xclang /opt/rocm/amdgcn/bitcode/oclc_abi_version_400.bc" -DBUILD_DEV=On """ : """ -DGPU_TARGETS="gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " -DBUILD_DEV=On """}" + setup_args = """ -DGPU_TARGETS="gfx90a" -DBUILD_DEV=On """ } steps{ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') diff --git a/README.md b/README.md index 151da974a5..04199f11bf 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ CK utilizes two concepts to achieve performance portability and code maintainabi * A tile-based programming model * Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". -![ALT](/doc/image/ck_component.png "CK Components") +![ALT](/docs/data/ck_component.png "CK Components") ## Code Structure Current CK library are structured into 4 layers: @@ -16,7 +16,17 @@ Current CK library are structured into 4 layers: * "Instantiated Kernel and Invoker" layer * "Client API" layer -![ALT](/doc/image/ck_layer.png "CK Layers") +![ALT](/docs/data/ck_layer.png "CK Layers") + +## Documentation + +Run the steps below to build documentation locally. + +``` +cd docs +pip3 install -r .sphinx/requirements.txt +python3 -m sphinx -T -E -b html -d _build/doctrees -D language=en . _build/html +``` ## Contributors The list of developers and contributors is here: [Contributors](/CONTRIBUTORS.md) diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt index 971d5d9f1c..7ffedfeef3 100644 --- a/client_example/04_contraction/CMakeLists.txt +++ b/client_example/04_contraction/CMakeLists.txt @@ -1,8 +1,14 @@ -add_executable(client_contraction_scale contraction_scale.cpp) -target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device_operations) +add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) +target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_operations) -add_executable(client_contraction_bilinear contraction_bilinear.cpp) -target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations) +add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) +target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_operations) + +add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) +target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_operations) + +add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) +target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_operations) add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_operations) diff --git a/client_example/04_contraction/contraction_bilinear.cpp b/client_example/04_contraction/contraction_bilinear_fp32.cpp similarity index 100% rename from client_example/04_contraction/contraction_bilinear.cpp rename to client_example/04_contraction/contraction_bilinear_fp32.cpp diff --git a/client_example/04_contraction/contraction_bilinear_fp64.cpp b/client_example/04_contraction/contraction_bilinear_fp64.cpp new file mode 100644 index 0000000000..9238e4cd80 --- /dev/null +++ b/client_example/04_contraction/contraction_bilinear_fp64.cpp @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp" +#include "ck/library/utility/numeric.hpp" + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Bilinear; + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F64; +using CShuffleDataType = F64; +using DDataType = F64; +using DsDataType = ck::Tuple; +using EDataType = F64; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ +// kknn +#if 1 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// knnn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64, 1, 131072, 2048}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// mknn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{128, 1, 245760, 3840}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// mnnn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{128, 1, 245760, 3840}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64, 1, 131072, 2048}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +#endif + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 25) + { + const ck::index_t M0 = std::stoi(argv[1]); + const ck::index_t M1 = std::stoi(argv[2]); + + const ck::index_t N0 = std::stoi(argv[3]); + const ck::index_t N1 = std::stoi(argv[4]); + + const ck::index_t K0 = std::stoi(argv[5]); + const ck::index_t K1 = std::stoi(argv[6]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21]), std::stoi(argv[22])}; + + alpha = std::stof(argv[23]); + beta = std::stof(argv[24]); + } + else + { + printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n"); + printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg15 to 18: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg19 to 22: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg23 to 24: alpha, beta\n"); + exit(0); + } + + auto f_tensor_space_size = [](auto lengths, auto strides) { + std::size_t space_size = 1; + for(std::size_t i = 0; i < lengths.size(); ++i) + { + space_size += (lengths[i] - 1) * strides[i]; + } + return space_size; + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * + f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides)); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * + f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides)); + SimpleDeviceMem d_device_buf(sizeof(DDataType) * + f_tensor_space_size(d_ms_ns_lengths, d_ms_ns_strides)); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * + f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides)); + + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD< + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + ck::Tuple, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{alpha, beta}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = + op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + ck::index_t M = ck::accumulate_n( + e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/04_contraction/contraction_scale.cpp b/client_example/04_contraction/contraction_scale_fp32.cpp similarity index 100% rename from client_example/04_contraction/contraction_scale.cpp rename to client_example/04_contraction/contraction_scale_fp32.cpp diff --git a/client_example/04_contraction/contraction_scale_fp64.cpp b/client_example/04_contraction/contraction_scale_fp64.cpp new file mode 100644 index 0000000000..3c36aa21eb --- /dev/null +++ b/client_example/04_contraction/contraction_scale_fp64.cpp @@ -0,0 +1,270 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/contraction_scale.hpp" +#include "ck/library/utility/numeric.hpp" + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = Scale; + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F64; +using CShuffleDataType = F64; +using DsDataType = ck::Tuple<>; +using EDataType = F64; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ +// kkn +#if 1 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// knn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64, 1, 131072, 2048}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// mkn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{128, 1, 245760, 3840}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +// mnn +#elif 0 + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{128, 1, 245760, 3840}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{64, 1, 131072, 2048}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; +#endif + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 20) + { + const ck::index_t M0 = std::stoi(argv[1]); + const ck::index_t M1 = std::stoi(argv[2]); + + const ck::index_t N0 = std::stoi(argv[3]); + const ck::index_t N1 = std::stoi(argv[4]); + + const ck::index_t K0 = std::stoi(argv[5]); + const ck::index_t K1 = std::stoi(argv[6]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[7]), std::stoi(argv[8]), std::stoi(argv[9]), std::stoi(argv[10])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13]), std::stoi(argv[14])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17]), std::stoi(argv[18])}; + + scale = std::stof(argv[19]); + } + else + { + printf("arg1 to 6: M0, M1, N0, N1, K0, K1\n"); + printf("arg7 to 10: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg11 to 14: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg15 to 18: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg19: scale\n"); + exit(0); + } + + auto f_tensor_space_size = [](auto lengths, auto strides) { + std::size_t space_size = 1; + for(std::size_t i = 0; i < lengths.size(); ++i) + { + space_size += (lengths[i] - 1) * strides[i]; + } + return space_size; + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * + f_tensor_space_size(a_ms_ks_lengths, a_ms_ks_strides)); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * + f_tensor_space_size(b_ns_ks_lengths, b_ns_ks_strides)); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * + f_tensor_space_size(e_ms_ns_lengths, e_ms_ns_strides)); + + using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD< + NumDimM, + NumDimN, + NumDimK, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{scale}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + ck::index_t M = ck::accumulate_n( + e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index 7dc9b860c0..2b7d6fc806 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,6 +1,12 @@ +add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) +target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_operations) +add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) +target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_operations) + add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_operations) @@ -9,3 +15,6 @@ target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composab add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_operations) + +add_executable(client_gemm_quantization gemm_quantization.cpp) +target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations) diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp index bcb0cefa71..cf6807f0dd 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -28,16 +28,15 @@ using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Cla static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; -static constexpr ck::index_t K = 64; -static constexpr ck::index_t C = 32; -static constexpr ck::index_t Y = 3; -static constexpr ck::index_t X = 3; -static constexpr ck::index_t Hi = 71; -static constexpr ck::index_t Wi = 71; -static constexpr ck::index_t Ho = 36; -static constexpr ck::index_t Wo = 36; - +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W struct SimpleDeviceMem { SimpleDeviceMem() = delete; @@ -64,8 +63,8 @@ int main(int argc, char* argv[]) std::array bias_strides{K, 0, 1, 0, 0}; std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; - std::array out_lengths{G, N, C, Ho, Wo}; - std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; @@ -136,10 +135,11 @@ int main(int argc, char* argv[]) { float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + - G * sizeof(WeiDataType) * K * Y * X * C + - G * sizeof(OutDataType) * N * Ho * Wo * K; + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = + G * sizeof(InDataType) * N * Hi * Wi * C + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(BiasDataType) * K + G * sizeof(RequantScaleDataType) * K + + G * sizeof(OutDataType) * N * Ho * Wo * K; float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; @@ -162,11 +162,12 @@ int main(int argc, char* argv[]) } } - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - // run the best intance + if(best_op_id != -1) { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp index 26c7aa15e2..b8e6a493ef 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -26,15 +26,16 @@ using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clam static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; -static constexpr ck::index_t K = 64; -static constexpr ck::index_t C = 32; -static constexpr ck::index_t Y = 3; -static constexpr ck::index_t X = 3; -static constexpr ck::index_t Hi = 71; -static constexpr ck::index_t Wi = 71; -static constexpr ck::index_t Ho = 36; -static constexpr ck::index_t Wo = 36; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W +static constexpr float requant_scale = 0.5f; // requantize qAcc to qz struct SimpleDeviceMem { @@ -60,8 +61,8 @@ int main(int argc, char* argv[]) std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; std::array bias_strides{K, 0, 1, 0, 0}; - std::array out_lengths{G, N, C, Ho, Wo}; - std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; @@ -102,26 +103,27 @@ int main(int argc, char* argv[]) for(int i = 0; i < op_ptrs.size(); ++i) { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {bias.GetDeviceBuffer()}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - weight_lengths, - weight_strides, - {bias_lengths}, - {bias_strides}, - out_lengths, - out_strides, - conv_strides, - conv_dilations, - in_left_pad, - in_right_pad, - PassThrough{}, - PassThrough{}, - OutElementOp{0.5f, ActivationOp{}}); + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{requant_scale, ActivationOp{}}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); std::string op_name = op_ptr->GetTypeString(); @@ -130,10 +132,10 @@ int main(int argc, char* argv[]) { float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + - G * sizeof(WeiDataType) * K * Y * X * C + - G * sizeof(OutDataType) * N * Ho * Wo * K; + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = + G * sizeof(InDataType) * N * Hi * Wi * C + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(BiasDataType) * K + G * sizeof(OutDataType) * N * Ho * Wo * K; float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; @@ -156,33 +158,35 @@ int main(int argc, char* argv[]) } } - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - // run the best intance + if(best_op_id != -1) { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {bias.GetDeviceBuffer()}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - weight_lengths, - weight_strides, - {bias_lengths}, - {bias_strides}, - out_lengths, - out_strides, - conv_strides, - conv_dilations, - in_left_pad, - in_right_pad, - PassThrough{}, - PassThrough{}, - OutElementOp{0.5f, ActivationOp{}}); + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{requant_scale, ActivationOp{}}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp new file mode 100644 index 0000000000..7a216f027f --- /dev/null +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp @@ -0,0 +1,209 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using RequantScaleLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::TanH; +using OutElementOp = + ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W +static constexpr float sz_inv = 0.5f; // inverse of scale_z + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array bias_lengths{G, N, K, Ho, Wo}; + std::array bias_strides{K, 0, 1, 0, 0}; + std::array requant_scale_lengths{G, N, K, Ho, Wo}; + std::array requant_scale_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< + NumDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths, requant_scale_lengths}, + {bias_strides, requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{sz_inv, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = + G * sizeof(InDataType) * N * Hi * Wi * C + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(BiasDataType) * K + G * sizeof(RequantScaleDataType) * K + + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(best_op_id != -1) + { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer(), requant_scale.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths, requant_scale_lengths}, + {bias_strides, requant_scale_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{sz_inv, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp new file mode 100644 index 0000000000..7637f5c785 --- /dev/null +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using OutDataType = int8_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using BiasLayout = ck::tensor_layout::convolution::G_K; +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::TanH; +using OutElementOp = ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W +static constexpr float sacc = 0.5f; // scale of acc +static constexpr float sz_inv = 0.5f; // inverse of scale_z + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array weight_lengths{G, K, C, Y, X}; + std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array bias_lengths{G, N, K, Ho, Wo}; + std::array bias_strides{K, 0, 1, 0, 0}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array in_left_pad{1, 1}; + std::array in_right_pad{1, 1}; + std::array conv_strides{2, 2}; + std::array conv_dilations{1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + OutElementOp>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{sacc, sz_inv, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = + G * sizeof(InDataType) * N * Hi * Wi * C + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(BiasDataType) * K + G * sizeof(OutDataType) * N * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + // run the best intance + if(best_op_id != -1) + { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {bias.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {bias_lengths}, + {bias_strides}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{sacc, sz_inv, ActivationOp{}}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp index 475b2f03b4..c1c5a651eb 100644 --- a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp @@ -26,15 +26,15 @@ using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul2_C static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; -static constexpr ck::index_t K = 64; -static constexpr ck::index_t C = 32; -static constexpr ck::index_t Y = 3; -static constexpr ck::index_t X = 3; -static constexpr ck::index_t Hi = 71; -static constexpr ck::index_t Wi = 71; -static constexpr ck::index_t Ho = 36; -static constexpr ck::index_t Wo = 36; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W struct SimpleDeviceMem { @@ -60,8 +60,8 @@ int main(int argc, char* argv[]) std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; - std::array out_lengths{G, N, C, Ho, Wo}; - std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; @@ -130,10 +130,10 @@ int main(int argc, char* argv[]) { float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = G * sizeof(InDataType) * N * Hi * Wi * C + - G * sizeof(WeiDataType) * K * Y * X * C + - G * sizeof(OutDataType) * N * Ho * Wo * K; + std::size_t flop = G * 2 * N * K * C * Ho * Wo * Y * X; + std::size_t num_bytes = + G * sizeof(InDataType) * N * Hi * Wi * C + G * sizeof(WeiDataType) * K * Y * X * C + + G * sizeof(RequantScaleDataType) * K + G * sizeof(OutDataType) * N * Ho * Wo * K; float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; @@ -156,11 +156,12 @@ int main(int argc, char* argv[]) } } - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - // run the best intance + if(best_op_id != -1) { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; diff --git a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp index da7b7e6abf..f7c46a95fe 100644 --- a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp @@ -24,15 +24,16 @@ using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; - std::array out_lengths{G, N, C, Ho, Wo}; - std::array out_strides{N * Ho * Wo * C, Ho * Wo * C, 1, Wo * C, C}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; @@ -96,26 +97,27 @@ int main(int argc, char* argv[]) for(int i = 0; i < op_ptrs.size(); ++i) { - auto& op_ptr = op_ptrs[i]; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - weight_lengths, - weight_strides, - {}, - {}, - out_lengths, - out_strides, - conv_strides, - conv_dilations, - in_left_pad, - in_right_pad, - PassThrough{}, - PassThrough{}, - OutElementOp{0.5f, ActivationOp{}}); + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{requant_scale, ActivationOp{}}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); std::string op_name = op_ptr->GetTypeString(); @@ -150,33 +152,34 @@ int main(int argc, char* argv[]) } } - std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops - << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; - - // run the best intance + if(best_op_id != -1) { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - {}, - out.GetDeviceBuffer(), - in_lengths, - in_strides, - weight_lengths, - weight_strides, - {}, - {}, - out_lengths, - out_strides, - conv_strides, - conv_dilations, - in_left_pad, - in_right_pad, - PassThrough{}, - PassThrough{}, - OutElementOp{0.5f, ActivationOp{}}); + auto argument_ptr = + op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + weight_lengths, + weight_strides, + {}, + {}, + out_lengths, + out_strides, + conv_strides, + conv_dilations, + in_left_pad, + in_right_pad, + PassThrough{}, + PassThrough{}, + OutElementOp{requant_scale, ActivationOp{}}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/client_example/09_quantization/gemm_quantization.cpp b/client_example/09_quantization/gemm_quantization.cpp new file mode 100644 index 0000000000..242504b44f --- /dev/null +++ b/client_example/09_quantization/gemm_quantization.cpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using ActivationOp = PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using ADataType = int8_t; +using BDataType = int8_t; +using EDataType = int8_t; + +using ALayout = Row; +using BLayout = Col; +using ELayout = Row; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideE = 1024; + + float requant_scale = 0.03; + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem e_device_buf(sizeof(EDataType) * f_matrix_space_size(M, N, StrideE, ELayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD, + ELayout, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + AElementOp, + BElementOp, + CDEElementOp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_bytes = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id != -1) + { + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000..659e6769d8 --- /dev/null +++ b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt @@ -0,0 +1,2 @@ +add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) +target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_operations) \ No newline at end of file diff --git a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp new file mode 100644 index 0000000000..223ed29be9 --- /dev/null +++ b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using ADataType = F16; +using BDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = FastGelu; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::mt19937 gen(19391); + std::uniform_int_distribution<> distrib(1, 10); + int group_count = distrib(gen); + + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * distrib(gen)); + Ns.push_back(256 + 256 * distrib(gen)); + Ks.push_back(128 + 128 * distrib(gen)); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_a, p_b; + std::vector p_e; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideEs[i], {}}); + + p_a.push_back(a_dev_bufs[i].GetDeviceBuffer()); + p_b.push_back(b_dev_bufs[i].GetDeviceBuffer()); + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + auto p_ds = std::vector>{}; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + if(found) + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 78133af031..87bcb08e83 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -65,7 +65,8 @@ else() -Wuninitialized -Wunreachable-code -Wunused - + -Wno-reserved-identifier + -Werror -Wsign-compare -Wno-extra-semi-stmt ) diff --git a/doc/markdown/dockerhub.md b/doc/markdown/dockerhub.md deleted file mode 100644 index 91b6cb2295..0000000000 --- a/doc/markdown/dockerhub.md +++ /dev/null @@ -1,93 +0,0 @@ -## CK docker hub - -[Docker hub](https://hub.docker.com/r/rocm/composable_kernel) - -## Why do I need this? - -To make our lives easier and bring Composable Kernel dependencies together, we recommend using docker images. - -## So what is Composable Kernel? - -Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. - -To get the CK library - -``` -git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git -``` - -run a docker container - -``` -docker run \ --it \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ -/bin/bash -``` - -and build the CK - -``` -mkdir build && cd build - -# Need to specify target ID, example below is for gfx908 and gfx90a -cmake \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-O3" \ --D CMAKE_BUILD_TYPE=Release \ --D GPU_TARGETS="gfx908;gfx90a" \ -.. -``` - -and - -``` -make -j examples tests -``` - -To run all the test cases including tests and examples run - -``` -make test -``` - -We can also run specific examples or tests like - -``` -./bin/example_gemm_xdl_fp16 -./bin/test_gemm_fp16 -``` - -For more details visit [CK github repo](https://github.com/ROCmSoftwarePlatform/composable_kernel), [CK examples](https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/example), [even more CK examples](https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/client_example). - -## And what is inside? - -The docker images have everything you need for running CK including: - -* [ROCm](https://www.amd.com/en/graphics/servers-solutions-rocm) -* [CMake](https://cmake.org/) -* [Compiler](https://github.com/RadeonOpenCompute/llvm-project) - -## Which image is right for me? - -Let's take a look at the image naming, for example "ck_ub20.04_rocm5.4_release". The image specs are: - -* "ck" - made for running Composable Kernel -* "ub20.04" - based on Ubuntu 20.04 -* "rocm5.4" - ROCm platform version 5.4 -* "release" - compiler version is release - -So just pick the right image for your project dependencies and you're all set. - -## DIY starts here - -If you need to customize a docker image or just can't stop tinkering, feel free to adjust the [Dockerfile](https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/Dockerfile) for your needs. - -## License - -CK is released under the MIT [license](https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/LICENSE). diff --git a/doc/markdown/tutorial_hello_world.md b/doc/markdown/tutorial_hello_world.md deleted file mode 100644 index 297df10b5c..0000000000 --- a/doc/markdown/tutorial_hello_world.md +++ /dev/null @@ -1,191 +0,0 @@ -## CK Hello world - -## Motivation - -This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who would like to optimize their pipelines and squeeze every performance drop by adding Composable Kernel (CK) library to their projects. We would like to make the CK library approachable so the tutorial is not based on the latest release and doesn't have all the bleeding edge features, but it will be reproducible now and forever. - -During this tutorial we will have an introduction to the CK library, we will build it and run some examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go in depth and breadth and get familiar with other tools and ways to integrate CK into your project. - -## Description - -Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create new ones. The library has components required for majority of modern neural networks architectures including matrix multiplication, convolution, contraction, reduction, attention modules, variety of activation functions, fused operators and many more. - -So how do we (almost) reach the speed of light? CK acceleration abilities are based on: - -* Layered structure. -* Tile-based computation model. -* Tensor coordinate transformation. -* Hardware acceleration use. -* Support of low precision data types including fp16, bf16, int8 and int4. - -If you are excited and need more technical details and benchmarking results - read this awesome blog [post](https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224). - -For more details visit our [github repo](https://github.com/ROCmSoftwarePlatform/composable_kernel). - -## Hardware targets - -CK library fully supports "gfx908" and "gfx90a" GPU architectures and only some operators are supported for "gfx1030". Let's check the hardware you have at hand and decide on the target GPU architecture - -GPU Target AMD GPU -gfx908 Radeon Instinct MI100 -gfx90a Radeon Instinct MI210, MI250, MI250X -gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT - -There are also [cloud options](https://aws.amazon.com/ec2/instance-types/g4/) you can find if you don't have an AMD GPU at hand. - -## Build the library - -First let's clone the library and rebase to the tested version: - -``` -git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git -cd composable_kernel/ -git checkout tutorial_hello_world -``` - -To make our lives easier we prepared [docker images](https://hub.docker.com/r/rocm/composable_kernel) with all the necessary dependencies. Pick the right image and create a container. In this tutorial we use "rocm/composable_kernel:ck_ub20.04_rocm5.3_release" image, it is based on Ubuntu 20.04, ROCm v5.3, compiler release version. - -If your current folder is ${HOME}, start the docker container with - -``` -docker run \ --it \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${HOME}:/root/workspace \ -rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ -/bin/bash -``` - -If your current folder is different from ${HOME}, adjust the line `-v ${HOME}:/root/workspace` to fit your folder structure. - -Inside the docker container current folder is "~/workspace", library path is "~/workspace/composable_kernel", navigate to the library - -``` -cd composable_kernel/ -``` - -Create and go to the "build" directory - -``` -mkdir build && cd build -``` - -In the previous section we talked about target GPU architecture. Once you decide which one is right for you, run cmake using the right GPU_TARGETS flag - -``` -cmake \ --D CMAKE_PREFIX_PATH=/opt/rocm \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-O3" \ --D CMAKE_BUILD_TYPE=Release \ --D BUILD_DEV=OFF \ --D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. -``` - -If everything went well the cmake run will end up with: - -``` --- Configuring done --- Generating done --- Build files have been written to: "/root/workspace/composable_kernel/build" -``` - -Finally, we can build examples and tests - -``` -make -j examples tests -``` - -If everything is smooth, you'll see - -``` -Scanning dependencies of target tests -[100%] Built target tests -``` - -## Run examples and tests - -Examples are listed as test cases as well, so we can run all examples and tests with - -``` -ctest -``` - -You can check the list of all tests by running - -``` -ctest -N -``` - -We can also run them separately, here is a separate example execution. - -``` -./bin/example_gemm_xdl_fp16 1 1 1 -``` - -The arguments "1 1 1" mean that we want to run this example in the mode: verify results with CPU, initialize matrices with integers and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. - -If everything goes well and you have a device based on gfx908 or gfx90a architecture you should see something like - -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 -``` - -Meanwhile, running it on a gfx1030 device should result in - -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} -b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem -``` - -But don't panic, some of the operators are supported on gfx1030 architecture, so you can run a separate example like - -``` -./bin/example_gemm_dl_fp16 1 1 1 -``` - -and it should result in something nice similar to - -``` -a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096} -b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} -c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} -arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} -arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} -arg.c_grid_desc_m_n_{ 3840, 4096} -launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} -Warm up 1 time -Start running 10 times... -Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> -``` - -Or we can run a separate test - -``` -ctest -R test_gemm_fp16 -``` - -If everything goes well you should see something like - -``` -Start 121: test_gemm_fp16 -1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec - -100% tests passed, 0 tests failed out of 1 -``` - -## Summary - -In this tutorial we took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different configs to find out the best one for your hardware and task. - -P.S.: Don't forget to switch out the cloud instance if you have launched one, you can find better ways to spend your money for sure! diff --git a/docs/Doxyfile b/docs/.doxygen/Doxyfile similarity index 99% rename from docs/Doxyfile rename to docs/.doxygen/Doxyfile index 958b3b6f44..1084f94c81 100644 --- a/docs/Doxyfile +++ b/docs/.doxygen/Doxyfile @@ -51,7 +51,7 @@ PROJECT_BRIEF = "prototype interfaces compatible with ROCm platform and # pixels and the maximum width should not exceed 200 pixels. Doxygen will copy # the logo to the output directory. -PROJECT_LOGO = ./rocm.jpg +PROJECT_LOGO = # The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path # into which the generated documentation will be written. If a relative path is @@ -775,8 +775,10 @@ WARN_LOGFILE = # spaces. See also FILE_PATTERNS and EXTENSION_MAPPING # Note: If this tag is empty the current directory is searched. -INPUT = ../library/include \ - ../library/include/internal +INPUT = ../../include/ck/tensor_operation/gpu/grid \ + ../../include/ck/tensor_operation/gpu/block \ + ../../include/ck/tensor_operation/gpu/thread \ + ../../library/include/ck/library/utility # This tag can be used to specify the character encoding of the source files # that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses @@ -845,7 +847,7 @@ FILE_PATTERNS = *.c \ # be searched for input files as well. # The default value is: NO. -RECURSIVE = NO +RECURSIVE = YES # The EXCLUDE tag can be used to specify files and/or directories that should be # excluded from the INPUT source files. This way you can easily exclude a diff --git a/docs/.sphinx/_toc.yml.in b/docs/.sphinx/_toc.yml.in new file mode 100644 index 0000000000..ff21248873 --- /dev/null +++ b/docs/.sphinx/_toc.yml.in @@ -0,0 +1 @@ +root: index diff --git a/docs/.sphinx/requirements.in b/docs/.sphinx/requirements.in new file mode 100644 index 0000000000..36a9a45775 --- /dev/null +++ b/docs/.sphinx/requirements.in @@ -0,0 +1,2 @@ +git+https://github.com/RadeonOpenCompute/rocm-docs-core.git +sphinxcontrib-bibtex==2.5.0 diff --git a/docs/.sphinx/requirements.txt b/docs/.sphinx/requirements.txt new file mode 100644 index 0000000000..8618920ea6 --- /dev/null +++ b/docs/.sphinx/requirements.txt @@ -0,0 +1,287 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile requirements.in +# +accessible-pygments==0.0.4 + # via pydata-sphinx-theme +alabaster==0.7.13 + # via sphinx +asttokens==2.2.1 + # via stack-data +attrs==22.2.0 + # via + # jsonschema + # jupyter-cache +babel==2.12.1 + # via + # pydata-sphinx-theme + # sphinx +backcall==0.2.0 + # via ipython +beautifulsoup4==4.12.0 + # via pydata-sphinx-theme +breathe==4.34.0 + # via rocm-docs-core +certifi==2022.12.7 + # via requests +cffi==1.15.1 + # via pynacl +charset-normalizer==3.1.0 + # via requests +click==8.1.3 + # via + # jupyter-cache + # sphinx-external-toc +comm==0.1.3 + # via ipykernel +debugpy==1.6.6 + # via ipykernel +decorator==5.1.1 + # via ipython +deprecated==1.2.13 + # via pygithub +docutils==0.16 + # via + # breathe + # myst-parser + # pybtex-docutils + # pydata-sphinx-theme + # rocm-docs-core + # sphinx + # sphinxcontrib-bibtex +executing==1.2.0 + # via stack-data +fastjsonschema==2.16.3 + # via nbformat +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via rocm-docs-core +greenlet==2.0.2 + # via sqlalchemy +idna==3.4 + # via requests +imagesize==1.4.1 + # via sphinx +importlib-metadata==6.1.0 + # via + # jupyter-cache + # myst-nb +importlib-resources==5.10.4 + # via rocm-docs-core +ipykernel==6.22.0 + # via myst-nb +ipython==8.11.0 + # via + # ipykernel + # myst-nb +jedi==0.18.2 + # via ipython +jinja2==3.1.2 + # via + # myst-parser + # sphinx +jsonschema==4.17.3 + # via nbformat +jupyter-cache==0.5.0 + # via myst-nb +jupyter-client==8.1.0 + # via + # ipykernel + # nbclient +jupyter-core==5.3.0 + # via + # ipykernel + # jupyter-client + # nbformat +latexcodec==2.0.1 + # via pybtex +linkify-it-py==1.0.3 + # via myst-parser +markdown-it-py==2.2.0 + # via + # mdit-py-plugins + # myst-parser +markupsafe==2.1.2 + # via jinja2 +matplotlib-inline==0.1.6 + # via + # ipykernel + # ipython +mdit-py-plugins==0.3.5 + # via myst-parser +mdurl==0.1.2 + # via markdown-it-py +myst-nb==0.17.1 + # via rocm-docs-core +myst-parser[linkify]==0.18.1 + # via + # myst-nb + # rocm-docs-core +nbclient==0.5.13 + # via + # jupyter-cache + # myst-nb +nbformat==5.8.0 + # via + # jupyter-cache + # myst-nb + # nbclient +nest-asyncio==1.5.6 + # via + # ipykernel + # nbclient +packaging==23.0 + # via + # ipykernel + # pydata-sphinx-theme + # sphinx +parso==0.8.3 + # via jedi +pexpect==4.8.0 + # via ipython +pickleshare==0.7.5 + # via ipython +platformdirs==3.1.1 + # via jupyter-core +prompt-toolkit==3.0.38 + # via ipython +psutil==5.9.4 + # via ipykernel +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.2 + # via stack-data +pybtex==0.24.0 + # via + # pybtex-docutils + # sphinxcontrib-bibtex +pybtex-docutils==1.0.2 + # via sphinxcontrib-bibtex +pycparser==2.21 + # via cffi +pydata-sphinx-theme==0.13.1 + # via sphinx-book-theme +pygithub==1.57 + # via rocm-docs-core +pygments==2.14.0 + # via + # accessible-pygments + # ipython + # pydata-sphinx-theme + # sphinx +pyjwt==2.6.0 + # via pygithub +pynacl==1.5.0 + # via pygithub +pyrsistent==0.19.3 + # via jsonschema +python-dateutil==2.8.2 + # via jupyter-client +pyyaml==6.0 + # via + # jupyter-cache + # myst-nb + # myst-parser + # pybtex + # sphinx-external-toc +pyzmq==25.0.2 + # via + # ipykernel + # jupyter-client +requests==2.28.2 + # via + # pygithub + # sphinx +rocm-docs-core @ git+https://github.com/RadeonOpenCompute/rocm-docs-core.git + # via -r requirements.in +six==1.16.0 + # via + # asttokens + # latexcodec + # pybtex + # python-dateutil +smmap==5.0.0 + # via gitdb +snowballstemmer==2.2.0 + # via sphinx +soupsieve==2.4 + # via beautifulsoup4 +sphinx==4.3.1 + # via + # breathe + # myst-nb + # myst-parser + # pydata-sphinx-theme + # rocm-docs-core + # sphinx-book-theme + # sphinx-copybutton + # sphinx-design + # sphinx-external-toc + # sphinx-notfound-page + # sphinxcontrib-bibtex +sphinx-book-theme==1.0.0rc2 + # via rocm-docs-core +sphinx-copybutton==0.5.1 + # via rocm-docs-core +sphinx-design==0.3.0 + # via rocm-docs-core +sphinx-external-toc==0.3.1 + # via rocm-docs-core +sphinx-notfound-page==0.8.3 + # via rocm-docs-core +sphinxcontrib-applehelp==1.0.4 + # via sphinx +sphinxcontrib-bibtex==2.5.0 + # via + # -r requirements.in + # rocm-docs-core +sphinxcontrib-devhelp==1.0.2 + # via sphinx +sphinxcontrib-htmlhelp==2.0.1 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.3 + # via sphinx +sphinxcontrib-serializinghtml==1.1.5 + # via sphinx +sqlalchemy==1.4.47 + # via jupyter-cache +stack-data==0.6.2 + # via ipython +tabulate==0.9.0 + # via jupyter-cache +tornado==6.2 + # via + # ipykernel + # jupyter-client +traitlets==5.9.0 + # via + # comm + # ipykernel + # ipython + # jupyter-client + # jupyter-core + # matplotlib-inline + # nbclient + # nbformat +typing-extensions==4.5.0 + # via + # myst-nb + # myst-parser +uc-micro-py==1.0.1 + # via linkify-it-py +urllib3==1.26.15 + # via requests +wcwidth==0.2.6 + # via prompt-toolkit +wrapt==1.15.0 + # via deprecated +zipp==3.15.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/docs/API_Reference_Guide.rst b/docs/API_Reference_Guide.rst new file mode 100644 index 0000000000..b59c6e3026 --- /dev/null +++ b/docs/API_Reference_Guide.rst @@ -0,0 +1,52 @@ + +******************* +API Reference Guide +******************* + +================= +Introduction +================= + +This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design +principles that are used to write new classes that extend CK functionality. + +================= +Using CK API +================= + +This section describes how to use the CK library API. + +================= +CK Datatypes +================= + +----------------- +DeviceMem +----------------- + +.. doxygenstruct:: DeviceMem + +--------------------------- +Kernels For Flashattention +--------------------------- + +The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists the classes that are +used in the CK GPU implementation of Flashattention. + +**Gridwise classes** + +.. doxygenstruct:: ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle + +**Blockwise classes** + +.. doxygenstruct:: ck::ThreadGroupTensorSliceTransfer_v4r1 + +.. doxygenstruct:: ck::BlockwiseGemmXdlops_v2 + +.. doxygenstruct:: ck::BlockwiseSoftmax + +**Threadwise classes** + +.. doxygenstruct:: ck::ThreadwiseTensorSliceTransfer_StaticToStatic + +.. bibliography:: diff --git a/docs/source/Contributors_Guide.rst b/docs/Contributors_Guide.rst similarity index 100% rename from docs/source/Contributors_Guide.rst rename to docs/Contributors_Guide.rst diff --git a/docs/source/Supported_Primitives_Guide.rst b/docs/Supported_Primitives_Guide.rst similarity index 99% rename from docs/source/Supported_Primitives_Guide.rst rename to docs/Supported_Primitives_Guide.rst index 066e024bc0..4c3adf67d7 100644 --- a/docs/source/Supported_Primitives_Guide.rst +++ b/docs/Supported_Primitives_Guide.rst @@ -72,4 +72,4 @@ Else if :math:`j>1`, \tilde{Y}_{ij} &= \diag(z^{new}_{i})^{-1} \exp(\tilde{m}_{ij} - m^{new}_i ) \tilde{P}_{ij} \\ z_i &= z^{new}_i \\ m_i &= m^{new}_i \\ - \end{align} \ No newline at end of file + \end{align} diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000000..3ec81ee9df --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,25 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +from rocm_docs import ROCmDocs + +docs_core = ROCmDocs("Composable Kernel Documentation") +docs_core.run_doxygen() +docs_core.setup() + +mathjax3_config = { +'tex': { + 'macros': { + 'diag': '\\operatorname{diag}', + } + } +} + +for sphinx_var in ROCmDocs.SPHINX_VARS: + globals()[sphinx_var] = getattr(docs_core, sphinx_var) + +extensions += ['sphinxcontrib.bibtex'] +bibtex_bibfiles = ['refs.bib'] diff --git a/doc/image/ck_component.png b/docs/data/ck_component.png similarity index 100% rename from doc/image/ck_component.png rename to docs/data/ck_component.png diff --git a/doc/image/ck_layer.png b/docs/data/ck_layer.png similarity index 100% rename from doc/image/ck_layer.png rename to docs/data/ck_layer.png diff --git a/docs/source/dockerhub.rst b/docs/dockerhub.rst similarity index 100% rename from docs/source/dockerhub.rst rename to docs/dockerhub.rst diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000000..f4e66c1b51 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,52 @@ +============================ +Composable Kernel User Guide +============================ + +------------ +Introduction +------------ + +This document contains instructions for installing, using, and contributing to Composable Kernel (CK). + +----------- +Methodology +----------- + +Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++. + +CK utilizes two concepts to achieve performance portability and code maintainability: + +* A tile-based programming model +* Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation". + +.. image:: data/ck_component.png + :alt: CK Components + +-------------- +Code Structure +-------------- + +Current CK library are structured into 4 layers: + +* "Templated Tile Operators" layer +* "Templated Kernel and Invoker" layer +* "Instantiated Kernel and Invoker" layer +* "Client API" layer + +.. image:: data/ck_layer.png + :alt: CK Layers + +Documentation Roadmap +^^^^^^^^^^^^^^^^^^^^^ +The following is a list of CK documents in the suggested reading order: + +.. toctree:: + :maxdepth: 5 + :caption: Contents: + :numbered: + + tutorial_hello_world + dockerhub + Supported_Primitives_Guide + API_Reference_Guide + Contributors_Guide diff --git a/docs/refs.bib b/docs/refs.bib new file mode 100644 index 0000000000..3c6d775a7e --- /dev/null +++ b/docs/refs.bib @@ -0,0 +1,7 @@ + +@article{dao2022flashattention, + title={Flashattention: Fast and memory-efficient exact attention with io-awareness}, + author={Dao, Tri and Fu, Daniel Y and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + journal={arXiv preprint arXiv:2205.14135}, + year={2022} +} diff --git a/docs/run_doc.sh b/docs/run_doc.sh deleted file mode 100755 index 58b0936c67..0000000000 --- a/docs/run_doc.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -set -eu - -# Make this directory the PWD -cd "$(dirname "${BASH_SOURCE[0]}")" - -# Build doxygen info -bash run_doxygen.sh - -# Build sphinx docs -cd source -make clean -make -e SPHINXOPTS="-t html" html -make latexpdf diff --git a/docs/run_doxygen.sh b/docs/run_doxygen.sh deleted file mode 100755 index f66c038c1b..0000000000 --- a/docs/run_doxygen.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - -set -eu - -# Make this directory the PWD -cd "$(dirname "${BASH_SOURCE[0]}")" - -# Build the doxygen info -rm -rf docBin -doxygen Doxyfile diff --git a/docs/source/API_Reference_Guide.rst b/docs/source/API_Reference_Guide.rst deleted file mode 100644 index 1ad2ecd9a9..0000000000 --- a/docs/source/API_Reference_Guide.rst +++ /dev/null @@ -1,23 +0,0 @@ - -=================== -API Reference Guide -=================== - ------------- -Introduction ------------- - -This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design -principles that are used to write new classes that extend CK functionality. - -================= -Using CK API -================= - -This section describes how to use the CK library API. - ------------------ -CK Datatypes ------------------ - -[TODO] \ No newline at end of file diff --git a/docs/source/Disclaimer.rst b/docs/source/Disclaimer.rst deleted file mode 100644 index 5dcff748c8..0000000000 --- a/docs/source/Disclaimer.rst +++ /dev/null @@ -1,13 +0,0 @@ -************ -Disclaimer -************ -------------------------------- -AMD's standard legal Disclaimer -------------------------------- - -The information presented in this document is for informational purposes only and may contain technical inaccuracies, omissions, and typographical errors. The information contained herein is subject to change and may be rendered inaccurate for many reasons, including but not limited to product and roadmap changes, component and motherboard version changes, new model and/or product releases, product differences between differing manufacturers, software changes, BIOS flashes, firmware upgrades, or the like. Any computer system has risks of security vulnerabilities that cannot be completely prevented or mitigated. AMD assumes no obligation to update or otherwise correct or revise this information. However, AMD reserves the right to revise this information and to make changes from time to time to the content hereof without obligation of AMD to notify any person of such revisions or changes. THIS INFORMATION IS PROVIDED 'AS IS." AMD MAKES NO REPRESENTATIONS OR WARRANTIES WITH RESPECT TO THE CONTENTS HEREOF AND ASSUMES NO RESPONSIBILITY FOR ANY INACCURACIES, ERRORS, OR OMISSIONS THAT MAY APPEAR IN THIS INFORMATION. AMD SPECIFICALLY DISCLAIMS ANY IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR ANY PARTICULAR PURPOSE. IN NO EVENT WILL AMD BE LIABLE TO ANY PERSON FOR ANY RELIANCE, DIRECT, INDIRECT, SPECIAL, OR OTHER CONSEQUENTIAL DAMAGES ARISING FROM THE USE OF ANY INFORMATION CONTAINED HEREIN, EVEN IF AMD IS EXPRESSLY ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. AMD, the AMD Arrow logo, Radeon, Ryzen, Epyc, and combinations thereof are trademarks of Advanced Micro Devices, Inc. Other product names used in this publication are for identification purposes only and may be trademarks of their respective companies. Google(R) is a registered trademark of Google LLC. PCIe(R) is a registered trademark of PCI-SIG Corporation. Linux(R) is the registered trademark of Linus Torvalds in the U.S. and other countries. Ubuntu(R) and the Ubuntu logo are registered trademarks of Canonical Ltd. Other product names used in this publication are for identification purposes only and may be trademarks of their respective companies. (C)2023 Advanced Micro Devices, Inc. All rights reserved. - ----------------------- -Third Party Disclaimer ----------------------- -Third-party content is licensed to you directly by the third party that owns the content and is not licensed to you by AMD. ALL LINKED THIRD-PARTY CONTENT IS PROVIDED "AS IS" WITHOUT A WARRANTY OF ANY KIND. USE OF SUCH THIRD-PARTY CONTENT IS DONE AT YOUR SOLE DISCRETION AND UNDER NO CIRCUMSTANCES WILL AMD BE LIABLE TO YOU FOR ANY THIRD-PARTY CONTENT. YOU ASSUME ALL RISK AND ARE SOLELY RESPONSIBLE FOR ANY DAMAGES THAT MAY ARISE FROM YOUR USE OF THIRD-PARTY CONTENT. diff --git a/docs/source/Linux_Install_Guide.rst b/docs/source/Linux_Install_Guide.rst deleted file mode 100644 index 0e16bb6a98..0000000000 --- a/docs/source/Linux_Install_Guide.rst +++ /dev/null @@ -1,15 +0,0 @@ -===================== -Getting Started Guide -===================== - ------------- -Introduction ------------- - -This document contains instructions for installing, using, and contributing to Composable Kernel (CK). - -Documentation Roadmap -^^^^^^^^^^^^^^^^^^^^^ -The following is a list of CK documents in the suggested reading order: - -[TODO] \ No newline at end of file diff --git a/docs/source/Makefile b/docs/source/Makefile deleted file mode 100644 index bde66ebc25..0000000000 --- a/docs/source/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -SPHINXPROJ = CK -SOURCEDIR = . -BUILDDIR = _build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index 8968e2fbe6..0000000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,216 +0,0 @@ -"""Copyright (C) 2018-2023 Advanced Micro Devices, Inc. All rights reserved. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- - ies of the Software, and to permit persons to whom the Software is furnished - to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- - PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS - FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR - COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER - IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- - CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -""" - -# -*- coding: utf-8 -*- -# -# Composable Kernel (CK) docuumentation build configuration file, based on -# rocBLAS documentation build configuration file, created by -# sphinx-quickstart on Mon Jan 8 16:34:42 2018. -# -# This file is execfile()d with the current directory set to its -# containing dir. -# -# Note that not all possible configuration values are present in this -# autogenerated file. -# -# All configuration values have a default; values that are commented out -# serve to show the default. - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) - -import os -import sys -import subprocess - -read_the_docs_build = os.environ.get('READTHEDOCS', None) == 'True' - -if read_the_docs_build: - subprocess.call('../run_doxygen.sh') - -# -- General configuration ------------------------------------------------ - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = ['sphinx.ext.mathjax', 'breathe'] -breathe_projects = { "CK": "../docBin/xml" } -breathe_default_project = "CK" - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -# source_suffix = ['.rst', '.md'] -source_suffix = '.rst' - -# The master toctree document. -master_doc = 'index' - -# General information about the project. -project = u'Composable Kernel (CK)' -copyright = u'2018-2023, Advanced Micro Devices' -author = u'Advanced Micro Devices' - -# The version info for the project you're documenting, acts as replacement for -# |version| and |release|, also used in various other places throughout the -# built documents. -# -# The short X.Y version. -#version = u'0.8' -# The full version, including alpha/beta/rc tags. -#release = u'0.8' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = 'en' - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - - -# -- Options for HTML output ---------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -# html_theme = 'alabaster' - -#if read_the_docs_build: -# html_theme = 'default' -#else: -import sphinx_rtd_theme -html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] -html_logo = "rocm_logo.png" - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -html_theme_options = { - 'logo_only': True, - 'display_version': True -} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] - -# Custom sidebar templates, must be a dictionary that maps document names -# to template names. -# -# This is required for the alabaster theme -# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars -# html_sidebars = { -# '**': [ -# 'relations.html', # needs 'show_related': True theme option to display -# 'searchbox.html', -# ] -# } - -mathjax3_config = { -'tex': { - 'macros': { - 'diag': '\\operatorname{diag}', - } - } -} - -# -- Options for HTMLHelp output ------------------------------------------ - -# Output file base name for HTML help builder. -htmlhelp_basename = 'CKdoc' - - -# -- Options for LaTeX output --------------------------------------------- - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - 'preamble': r''' -\setcounter{tocdepth}{5} -\newcommand{\diag}{\operatorname{diag}} -''', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'CK.tex', u'Composabl Kernel (CK) Documentation', - u'Advanced Micro Devices', 'manual'), -] - - -# -- Options for manual page output --------------------------------------- - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'ck', u'Composable Kernel (CK) Documentation', - [author], 1) -] - - -# -- Options for Texinfo output ------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - (master_doc, 'CK', u'Composable Kernel (CK) Documentation', - author, 'CK', 'Composable Kernel for AMD ROCm', - 'Miscellaneous'), -] diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100644 index 68adf58afd..0000000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,16 +0,0 @@ -============================ -Composable Kernel User Guide -============================ - -.. toctree:: - :maxdepth: 5 - :caption: Contents: - :numbered: - - Linux_Install_Guide - tutorial_hello_world - dockerhub - Supported_Primitives_Guide - API_Reference_Guide - Contributors_Guide - Disclaimer \ No newline at end of file diff --git a/docs/source/rocm_logo.png b/docs/source/rocm_logo.png deleted file mode 100644 index ee09dd09c7..0000000000 Binary files a/docs/source/rocm_logo.png and /dev/null differ diff --git a/docs/source/tutorial_hello_world.rst b/docs/tutorial_hello_world.rst similarity index 100% rename from docs/source/tutorial_hello_world.rst rename to docs/tutorial_hello_world.rst diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 7f8fdf35f4..c5a8295188 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -38,7 +38,7 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) -if(GPU_TARGETS MATCHES "gfx1100") +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") add_custom_target(example_gemm_wmma) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index 48bcca257a..58f965be88 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -19,15 +19,15 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>; +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmMNKPadding, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 1343a814ad..16a8211027 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,4 +1,4 @@ add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) -if(GPU_TARGETS MATCHES "gfx1100") +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) endif() diff --git a/example/04_gemm_add_add_fastgelu/common.hpp b/example/04_gemm_add_add_fastgelu/common.hpp index 3f9375e092..839587c148 100644 --- a/example/04_gemm_add_add_fastgelu/common.hpp +++ b/example/04_gemm_add_add_fastgelu/common.hpp @@ -62,7 +62,7 @@ struct ExecutionConfig final }; inline bool -parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig config) +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) { if(argc == 1) { diff --git a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp index 5e50c14dc2..ba0476b9b9 100644 --- a/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp +++ b/example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_bf16.cpp @@ -7,10 +7,11 @@ using ADataType = BF16; using BDataType = BF16; using AccDataType = F32; using CShuffleDataType = F32; -using D0DataType = BF16; -using D1DataType = BF16; -using DsDataType = ck::Tuple; -using EDataType = BF16; +using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification +using D0DataType = BF16; +using D1DataType = BF16; +using DsDataType = ck::Tuple; +using EDataType = BF16; using ALayout = Row; using BLayout = Col; @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -using EDataType = F16; +using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification +using D0DataType = F16; +using D1DataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; using ALayout = Row; using BLayout = Col; @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -using EDataType = F32; +using CDataType = F32; // C matrix doesn't exsit in GPU memory, this is used for host verification +using D0DataType = F32; +using D1DataType = F32; +using DsDataType = ck::Tuple; +using EDataType = F32; using ALayout = Row; using BLayout = Col; @@ -36,7 +36,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -using EDataType = I4; +using CDataType = I32; // C matrix doesn't exsit in GPU memory, this is used for host verification +using D0DataType = I4; +using D1DataType = I4; +using DsDataType = ck::Tuple; +using EDataType = I4; using KernelADataType = I8; using KernelBDataType = I8; @@ -47,7 +48,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; -using EDataType = I8; +using CDataType = I32; // C matrix doesn't exsit in GPU memory, this is used for host verification +using D0DataType = I8; +using D1DataType = I8; +using DsDataType = ck::Tuple; +using EDataType = I8; using ALayout = Row; using BLayout = Col; @@ -36,7 +37,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm c_m_n({M, N}); + Tensor c_m_n({M, N}); auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index ca09c48c10..8ea11df9ce 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,2 +1,6 @@ +# dlops +add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) + +# xdlops add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) \ No newline at end of file diff --git a/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp new file mode 100644 index 0000000000..044f3c166a --- /dev/null +++ b/example/14_gemm_quantization/gemm_dl_quantization_int8.cpp @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using ActivationOp = PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DsDataType = ck::Tuple<>; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Dl< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + DsDataType, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmDefault, + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 4, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main() +{ + bool do_verification = true; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = 1024; + ck::index_t StrideB = 1024; + ck::index_t StrideE = 1024; + + float requant_scale = 0.03; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1_uz})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1_uz, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 67f6160873..e3080c4a9a 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -4,12 +4,15 @@ add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) +add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32 example_grouped_gemm_xdl_fp16 example_grouped_gemm_xdl_bfp16 - example_grouped_gemm_xdl_int8) + example_grouped_gemm_xdl_int8 + example_grouped_gemm_multiple_d_dl_fp16) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp new file mode 100644 index 0000000000..a5c51ceb0c --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_dl_fp16.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device:: + // ##################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ##################| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ##################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleD_Dl< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index dc45db9865..7891812375 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -26,7 +26,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, { split_k = 1; } - + const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< InputLayout>(conv_param); diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index c74294feb0..32a87dd200 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,5 +1,5 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) -if(GPU_TARGETS MATCHES "gfx1100") +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) endif() diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index acf9bcdb46..4b0ea4f157 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -16,7 +16,7 @@ if(USE_BITINT_EXTENSION_INT4) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) endif() # USE_BITINT_EXTENSION_INT4 -if(GPU_TARGETS MATCHES "gfx1100") +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) endif() diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp index 6de1daa3d4..498eda2442 100644 --- a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index 8161b1088a..a6888649c0 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -74,8 +74,8 @@ using DeviceConvFwdInstance = 8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferDstScalarPerVector_BK1 true, // BBlockLdsExtraN - 1, - 1, + 4, + 2, S<1, 32, 1, 8>, 8>; diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt new file mode 100644 index 0000000000..0a314cd74c --- /dev/null +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -0,0 +1,21 @@ +# Conv perlayer quantization +add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) + +# Conv perchannel quantization +add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) + +# Conv + bias + relu perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) + +# Conv + bias + relu perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) + +# Conv + bias + tanh perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) + +# Conv + bias + tanh perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) diff --git a/example/40_conv2d_fwd_quantization/common.hpp b/example/40_conv2d_fwd_quantization/common.hpp new file mode 100644 index 0000000000..6ee14d750e --- /dev/null +++ b/example/40_conv2d_fwd_quantization/common.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp new file mode 100644 index 0000000000..5c445d9c50 --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using AccDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + NDimSpatial, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + AccDataType, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 4, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +#include "run_conv2d_fwd_bias_perchannel_quantization_example.inc" + +int main() +{ + const auto out_element_op = OutElementOp{ActivationOp{}}; + run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); +}; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000..0ff85f008f --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using AccDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + NDimSpatial, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + AccDataType, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 4, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +#include "run_conv2d_fwd_bias_perlayer_quantization_example.inc" + +int main() +{ + float requant_scale = 0.5f; + const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; + run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp new file mode 100644 index 0000000000..f8f996d17e --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using AccDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::TanH; +using OutElementOp = + ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + NDimSpatial, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + AccDataType, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 4, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +#include "run_conv2d_fwd_bias_perchannel_quantization_example.inc" + +int main() +{ + float scale_z_inv = 0.5f; + const auto out_element_op = OutElementOp{scale_z_inv, ActivationOp{}}; + run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); +}; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000..3b25fec0c4 --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using AccDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::TanH; +using OutElementOp = ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + NDimSpatial, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + AccDataType, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 4, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +#include "run_conv2d_fwd_bias_perlayer_quantization_example.inc" + +int main() +{ + float scale_acc = 0.5f; + float scale_z_inv = 0.5f; + const auto out_element_op = OutElementOp{scale_z_inv, scale_acc, ActivationOp{}}; + run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp new file mode 100644 index 0000000000..a98a1e240b --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perchannel_quantization_int8.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using RequantScaleDataType = float; +using AccDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + NDimSpatial, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + AccDataType, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 4, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +#include "run_conv2d_fwd_perchannel_quantization_example.inc" + +int main() +{ + const auto out_element_op = OutElementOp{ActivationOp{}}; + run_conv2d_fwd_perchannel_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000..262594d58b --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_dl_perlayer_quantization_int8.cpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using AccDataType = int32_t; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< + NDimSpatial, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + AccDataType, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 4, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<8, 1, 1, 4>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<8, 1, 1, 4>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<2, 1, 128, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + S<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + S<4, 1, 1, 4>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 4>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector + +#include "run_conv2d_fwd_perlayer_quantization_example.inc" + +int main() +{ + float requant_scale = 0.5f; + const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; + run_conv2d_fwd_perlayer_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp new file mode 100644 index 0000000000..6b22055053 --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using RequantScaleDataType = float; +using AccDataType = int32_t; +using CShuffleDataType = AccDataType; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 8>; + +#include "run_conv2d_fwd_bias_perchannel_quantization_example.inc" + +int main() +{ + const auto out_element_op = OutElementOp{ActivationOp{}}; + run_conv2d_fwd_bias_perchannel_quantization_example(out_element_op); +}; diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000..1ac8679743 --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using BiasDataType = int32_t; +using AccDataType = int32_t; +using CShuffleDataType = AccDataType; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 8>; + +#include "run_conv2d_fwd_bias_perlayer_quantization_example.inc" + +int main() +{ + float requant_scale = 0.5f; + const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; + run_conv2d_fwd_bias_perlayer_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp new file mode 100644 index 0000000000..f28abe5ebc --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perchannel_quantization_int8.cpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using RequantScaleDataType = float; +using AccDataType = int32_t; +using CShuffleDataType = AccDataType; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = ck::tensor_operation::element_wise::Relu; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 8>; + +#include "run_conv2d_fwd_perchannel_quantization_example.inc" + +int main() +{ + const auto out_element_op = OutElementOp{ActivationOp{}}; + run_conv2d_fwd_perchannel_quantization_example(out_element_op); +} diff --git a/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp new file mode 100644 index 0000000000..f468e8adcd --- /dev/null +++ b/example/40_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +using InDataType = int8_t; +using WeiDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = AccDataType; +using OutDataType = int8_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using ActivationOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 64, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 64, 1, 4>, + 16>; + +#include "run_conv2d_fwd_perlayer_quantization_example.inc" + +int main() +{ + float requant_scale = 0.5f; + const auto out_element_op = OutElementOp{requant_scale, ActivationOp{}}; + run_conv2d_fwd_perlayer_quantization_example(out_element_op); +} diff --git a/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc similarity index 75% rename from example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp rename to example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc index 832665edc0..1587c614da 100644 --- a/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc @@ -1,97 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/algorithm.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/utility/convolution_parameter.hpp" -#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" - -using InDataType = int8_t; -using WeiDataType = int8_t; -using BiasDataType = int32_t; -using RequantScaleDataType = float; -using AccDataType = int32_t; -using CShuffleDataType = int32_t; -using OutDataType = int8_t; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using InElementOp = PassThrough; -using WeiElementOp = PassThrough; -using ActivationOp = ck::tensor_operation::element_wise::Relu; -using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; - -static constexpr auto ConvSpec = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -template -using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< - NDimSpatial, - InLayout, - WeiLayout, - ck::Tuple, - OutLayout, - InDataType, - WeiDataType, - AccDataType, - CShuffleDataType, - ck::Tuple, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - ConvSpec, // ConvForwardSpecialization - GemmSpec, // GemmSpecialization - 1, // - 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 64, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN - 1, - 1, - S<1, 64, 1, 4>, - 8>; - +#pragma once template c_host(out_g_n_k_wos_desc); + Tensor c_host(out_g_n_k_wos_desc); auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); @@ -257,7 +167,7 @@ bool run_grouped_conv_fwd(bool do_verification, return (pass ? 0 : 1); } -int main() +int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_element_op) { bool do_verification = true; bool time_kernel = true; @@ -268,7 +178,7 @@ int main() 1, // group 4, // batch 64, // output channels - 32, // input chanels + 192, // input chanels {3, 3}, // weight HW {71, 71}, // x HW {2, 2}, // strides @@ -279,7 +189,6 @@ int main() const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{ActivationOp{}}; using InLayout = ck::tensor_layout::convolution::GNHWC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; @@ -312,8 +221,6 @@ int main() const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); - std::cout << out_g_n_k_wos_desc << std::endl; - using deviceOp = DeviceGroupedConvNDFwdInstance -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using InElementOp = PassThrough; -using WeiElementOp = PassThrough; -using ActivationOp = ck::tensor_operation::element_wise::Relu; -using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; - -static constexpr auto ConvSpec = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -template -using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< - NDimSpatial, - InLayout, - WeiLayout, - ck::Tuple, - OutLayout, - InDataType, - WeiDataType, - AccDataType, - CShuffleDataType, - ck::Tuple, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - ConvSpec, // ConvForwardSpecialization - GemmSpec, // GemmSpecialization - 1, // - 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 64, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN - 1, - 1, - S<1, 64, 1, 4>, - 8>; +#pragma once template c_host(out_g_n_k_wos_desc); + Tensor c_host(out_g_n_k_wos_desc); auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); @@ -242,7 +155,7 @@ bool run_grouped_conv_fwd(bool do_verification, return (pass ? 0 : 1); } -int main() +int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op) { bool do_verification = true; bool time_kernel = true; @@ -253,7 +166,7 @@ int main() 1, // group 4, // batch 64, // output channels - 32, // input chanels + 192, // input chanels {3, 3}, // weight HW {71, 71}, // x HW {2, 2}, // strides @@ -264,7 +177,6 @@ int main() const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - const auto out_element_op = OutElementOp{0.5f, ActivationOp{}}; using InLayout = ck::tensor_layout::convolution::GNHWC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; @@ -294,8 +206,6 @@ int main() const auto out_g_n_k_wos_desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); - std::cout << out_g_n_k_wos_desc << std::endl; - return run_grouped_conv_fwd< ndim_spatial, InDataType, diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc new file mode 100644 index 0000000000..8e75c27746 --- /dev/null +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +bool run_grouped_conv_fwd(bool do_verification, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& requant_scale_g_k_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor requant_scale(requant_scale_g_k_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "requant_scale: " << requant_scale.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + in.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + wei.GenerateTensorValue(GeneratorTensor_2{-128, 127}); + requant_scale.GenerateTensorValue(GeneratorTensor_2{0, 1}); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem requant_scale_device_buf(sizeof(RequantScaleDataType) * + requant_scale.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + requant_scale_device_buf.ToDevice(requant_scale.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d0_g_n_k_wos_lengths{}; + std::array d0_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(requant_scale_g_k_desc.GetLengths(), d0_g_n_k_wos_lengths); + copy(requant_scale_g_k_desc.GetStrides(), d0_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + {requant_scale_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + {d0_g_n_k_wos_lengths}, + {d0_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor c_host(out_g_n_k_wos_desc); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + // TODO: implement elementwise operation for host + out_host.ForEach([&](auto&, auto idx) { + out_element_op(out_host(idx), c_host(idx), requant_scale(idx)); + }); + + out_device_buf.FromDevice(out_device.mData.data()); + + pass &= + ck::utils::check_err(out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); + } + + return (pass ? 0 : 1); +} + +int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op) +{ + bool do_verification = true; + bool time_kernel = true; + const ck::index_t ndim_spatial = 2; + + ck::utils::conv::ConvParam conv_param{ + ndim_spatial, // n_dim + 1, // group + 4, // batch + 64, // output channels + 192, // input chanels + {3, 3}, // weight HW + {71, 71}, // x HW + {2, 2}, // strides + {1, 1}, // dilations + {1, 1}, // left_pads + {1, 1} // right_pads + }; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + + using InLayout = ck::tensor_layout::convolution::GNHWC; + using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using RequantScaleLayout = ck::tensor_layout::convolution::G_K; + using OutLayout = ck::tensor_layout::convolution::GNHWK; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + + const auto requant_scale_g_k_desc = + HostTensorDescriptor({conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1]}, + { + conv_param.K_, // g + 0, // n + 1, // k + 0, // ho + 0 // wo + }); + + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + using deviceOp = DeviceGroupedConvNDFwdInstance; + + return run_grouped_conv_fwd(do_verification, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + requant_scale_g_k_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); +} diff --git a/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc similarity index 71% rename from example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp rename to example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc index 2d46d86655..926c033c58 100644 --- a/example/44_conv2d_fwd_quantization/conv2d_fwd_xdl_perlayer_quantization_int8.cpp +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc @@ -1,89 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/algorithm.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/utility/convolution_parameter.hpp" -#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" - -using InDataType = int8_t; -using WeiDataType = int8_t; -using AccDataType = int32_t; -using CShuffleDataType = int32_t; -using OutDataType = int8_t; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using InElementOp = PassThrough; -using WeiElementOp = PassThrough; -using ActivationOp = PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; - -static constexpr auto ConvSpec = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -template -using DeviceGroupedConvNDFwdInstance = - ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< - NDimSpatial, - InLayout, - WeiLayout, - ck::Tuple<>, - OutLayout, - InDataType, - WeiDataType, - AccDataType, - CShuffleDataType, - ck::Tuple<>, - OutDataType, - InElementOp, - WeiElementOp, - OutElementOp, - ConvSpec, // ConvForwardSpecialization - GemmSpec, // GemmSpecialization - 1, // - 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 64, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN - 1, - 1, - S<1, 64, 1, 4>, - 16>; +#pragma once template "; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 47c821171b..c77772faa4 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -824,7 +824,19 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp index 1e2f81915d..1cc30fd9e6 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index bba306646d..0b1db28465 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -866,7 +866,16 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle << "B0Spec" << getTensorSpecializationString(BSpec) << ", " << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " << "CSpec" << getTensorSpecializationString(CSpec) << ", " - << getMaskingSpecializationString(MaskingSpec) << ">"; + << getMaskingSpecializationString(MaskingSpec) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index b1a78dc99b..493822aeb2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -770,7 +770,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(ck::get_device_name() == "gfx1100") + if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102") { if constexpr(!(is_same_v || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index 025e736b14..196dc86da1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -789,6 +789,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle return true; } + // check if DsLayout is supported + template + static bool CheckDLayout() + { + static bool valid = true; + // iterate over DLayout tuple + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // if RefLayout and DLayout are same, keep valid true, otherwise false + valid = valid && is_same_v; + }); + return valid; + } + static bool IsSupportedArgument(const Argument& arg) { if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || @@ -797,6 +811,26 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle return false; } + // Check supported layouts + // A0 - Row + // B0 - Col + // D0s - Rows + // B1 - Row or Col + // D1s - Rows + // E1 - Row + if(!(is_same_v && + is_same_v && + CheckDLayout() && + (is_same_v || + is_same_v)&&CheckDLayout() && + is_same_v)) + { + return false; + } + return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, arg.b0_grid_desc_n_k_, arg.b1_grid_desc_n_k_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index b8af157a8e..1eaffe705a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -588,6 +588,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle return false; } + if(ck::get_device_name() != "gfx90a" && std::is_same::value) + { + return false; + } + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, arg.ds_grid_desc_m_n_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 9f9967c96c..b65afce8df 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -777,7 +777,19 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 806b0c5925..ea3020663a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -822,7 +822,14 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index ff49d3b82e..e7e2bf3354 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -956,7 +956,19 @@ struct << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp index 4934599ee4..6c4957b9b1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -913,7 +913,19 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 731dfc5ea1..027f1a1954 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -880,7 +880,17 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W << MPerBlock << ", " << NPerBlock << ", " << K0PerBlock << ", " - << getConvForwardSpecializationString(ConvForwardSpecialization) + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 5848006931..6278220c2f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -720,7 +720,16 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << MPerBlock << ", " << NPerBlock << ", " << K0PerBlock << ", " - << getConvForwardSpecializationString(ConvForwardSpecialization) + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp index 69c842137e..d52879cd90 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -631,7 +631,19 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp index 5bef0e2a3e..2a2edc29ba 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -1567,7 +1567,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << K1 << ">"; if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp index 7951713938..1fd4b76cec 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -1552,7 +1552,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ">"; if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp index b9a64e8c4b..a9a58c8ac4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp @@ -862,7 +862,15 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO << NPerBlock << ", " << KPerBlock << ", " << AK1 << ", " - << BK1 + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp index a7bce886f4..9f9fe0f1c9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp @@ -561,7 +561,19 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index e497bac5d7..6368469309 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -51,7 +51,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx1030__)) + defined(__gfx90a__) || defined(__gfx1030__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); @@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index 66c4de7f05..b59bb2b303 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -86,120 +86,84 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); } -#ifdef ENABLE_COLMAJOR - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); } -#endif }(); - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + const index_t K0 = K / K1; - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); } - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); } }(); - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + const index_t K0 = K / K1; - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } template - static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) { - const auto e_grid_desc_m_n = [&]() { + const auto e_grid_desc_mraw_nraw = [&]() { if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); } }(); - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - e_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - - return transform_tensor_descriptor( - e_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); } static auto MakeDsGridDescriptor_M_N(const std::array& Ms, @@ -512,7 +476,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 0c845ab5b3..2488101484 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -682,6 +682,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD" << " LoopScheduler: " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp index 3f62601f96..f358bd176f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp @@ -822,7 +822,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio << NPerBlock << ", " << KPerBlock << ", " << AK1 << ", " - << BK1 + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index dbcceac68f..9b0ff3f46f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -12,6 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -78,119 +79,83 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); } -#ifdef ENABLE_COLMAJOR - else if constexpr(is_same::value) + else if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } -#endif - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); } }(); - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + const index_t K0 = K / K1; - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - else - { - return transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) { - const auto c_grid_desc_m_n = [&]() { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); } }(); - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); } // Gridwise descriptor, mapping to whole given provblem. @@ -439,7 +404,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v)) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp index 1075247833..a5051455bf 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -77,8 +77,6 @@ struct DeviceGemmXdl : public DeviceGemm" << " NumPrefetch: " << NumPrefetch << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp index 3419344b19..7cd0ff72e8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -683,7 +683,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm" << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp index 5e5dbca06d..8ee138f827 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp @@ -761,7 +761,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator << NPerBlock << ", " << KPerBlock << ", " << AK1 << ", " - << BK1 + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index c85b805f59..1f08cec67e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -73,157 +73,18 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK{}; static constexpr auto I3 = Number<3>{}; - static constexpr auto K1Number = Number{}; - - static auto - MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad) - { - assert(KPad % (K1 * KBatch) == 0); - - const index_t K0 = KPad / (K1 * KBatch); - - const auto a_grid_desc_m_k = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - const auto a_grid_desc_m_kpad = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - return transform_tensor_descriptor( - a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_right_pad_transform(M, PadM)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - else - { - return transform_tensor_descriptor( - a_grid_desc_m_kpad, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - } - - static auto - MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad) - { - assert(KPad % (K1 * KBatch) == 0); - - const index_t K0 = KPad / (K1 * KBatch); - - const auto b_grid_desc_k_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); - } - }(); - - const auto b_grid_desc_kpad_n = transform_tensor_descriptor( - b_grid_desc_k_n, - make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - return transform_tensor_descriptor( - b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - else - { - return transform_tensor_descriptor( - b_grid_desc_kpad_n, - make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - } - } - - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - const auto c_grid_desc_m_n = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - if constexpr(GemmSpec == GemmSpecialization::MNPadding) - { - const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; - const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - - return transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - } - - static auto GetKPad(index_t K, index_t KBatch) - { - const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; - const index_t KPad = KBatch * K0 * K1; - return KPad; - } - - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - - // GridwiseGemm using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, CDataType, - InMemoryDataOperationEnum::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, + ALayout, + BLayout, + CLayout, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, + GemmSpec, MPerBlock, NPerBlock, K0PerBlock, @@ -253,236 +114,64 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; - // GridwiseGemm - using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< - BlockSize, - ADataType, // TODO: distinguish A/B datatype - AccDataType, - CDataType, - InMemoryDataOperationEnum::AtomicAdd, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXDL, - NPerXDL, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - BBlockLdsAddExtraN, - CShuffleMRepeatPerShuffle, - CShuffleNRepeatPerShuffle, - CBlockTransferScalarPerVector_NWaveNPerXDL, - CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; - - using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); - - using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor; - - // Argument - struct Argument : public BaseArgument - { - Argument(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - index_t M01, - index_t N01, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - index_t k_batch) - : p_a_grid_{p_a_grid}, - p_b_grid_{p_b_grid}, - p_c_grid_{p_c_grid}, - a_grid_desc_kbatch_k0_m_k1_{}, - b_grid_desc_kbatch_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - c_element_op_{c_element_op}, - k_batch_{k_batch} - { - int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_); - - a_grid_desc_kbatch_k0_m_k1_ = - DeviceGemmXdlSplitKCShuffle::MakeAGridDescriptor_KBatch_K0_M_K1( - M, K, StrideA, k_batch_, KPad); - b_grid_desc_kbatch_k0_n_k1_ = - DeviceGemmXdlSplitKCShuffle::MakeBGridDescriptor_KBatch_K0_N_K1( - K, N, StrideB, k_batch_, KPad); - c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); - - block_2_ctile_map_ = - GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); - - if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, - b_grid_desc_kbatch_k0_n_k1_, - c_grid_desc_m_n_, - block_2_ctile_map_)) - { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; - Block2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CElementwiseOperation c_element_op_; - index_t k_batch_; - }; + using Argument = typename GridwiseGemm::Argument; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmXdlSplitKCShuffle::Argument; - void Print(const Argument& arg) - { - std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " - << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + void Print(const Argument& karg) { karg.Print(); } - std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " - << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) { if(stream_config.log_level_ > 0) { - Print(arg); + Print(karg); } - const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const auto kbatch = karg.k_batch; - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.block_2_ctile_map_)) + if(!GridwiseGemm::CheckValidity(karg)) { throw std::runtime_error( - "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"); + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); } - const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(karg); + const auto K0 = karg.K0; const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); float ave_time = 0; const auto Run = [&](const auto& kernel) { - hipGetErrorString(hipMemset( - arg.p_c_grid_, - 0, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * - sizeof(CDataType))); + if(kbatch > 1) + hipGetErrorString( + hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType))); - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_kbatch_k0_m_k1_, - arg.b_grid_desc_kbatch_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); }; if(has_main_k0_block_loop) { if(kbatch == 1) { - const auto kernel = kernel_gemm_xdlops_v2r4r2< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } else { - const auto kernel = kernel_gemm_xdlops_v2r4r2< - GridwiseGemmAtomicAdd, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } @@ -491,37 +180,19 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } else { - const auto kernel = kernel_gemm_xdlops_v2r4r2< - GridwiseGemmAtomicAdd, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; Run(kernel); } @@ -544,12 +215,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK(static_cast(p_a), @@ -615,11 +282,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK"; - // clang-format on - - return str.str(); - } + std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index d5346c0092..92c20a308f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -1004,7 +1004,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 << KPerBlock << ", " << AK1 << ", " << BK1 << ", " - << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ">"; return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp index b0681b724e..9529744f31 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp @@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl << MPerBlock << ", " << NPerBlock << ", " << K0PerBlock << ", " - << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp index 4f98cf23b2..c921c9f1b4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp @@ -1232,7 +1232,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle << MPerBlock << ", " << NPerBlock << ", " << K0PerBlock << ", " - << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index 263eb7197d..de40d71293 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -1093,7 +1093,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle << MPerBlock << ", " << NPerBlock << ", " << KPerBlock << ", " - << getConvForwardSpecializationString(ConvForwardSpecialization) + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index e245902b6c..9d4b68c0b6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -579,7 +579,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle namespace ctc = tensor_layout::convolution; // check device - if(get_device_name() == "gfx1100") + if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102") { if constexpr(!(is_same_v || is_same_v)) { @@ -837,7 +838,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle << MPerBlock << ", " << NPerBlock << ", " << KPerBlock << ", " - << getConvForwardSpecializationString(ConvForwardSpecialization) + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp index 2af764fcd9..02458bf02a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -940,7 +940,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle << MPerBlock << ", " << NPerBlock << ", " << KPerBlock << ", " - << getConvForwardSpecializationString(ConvForwardSpecialization) + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp new file mode 100644 index 0000000000..d424f29920 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -0,0 +1,763 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx1030__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && + block_id < gemm_desc_ptr[group_id].BlockEnd_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::Run(gemm_desc_ptr[group_id].a_ptr_, + gemm_desc_ptr[group_id].b_ptr_, + gemm_desc_ptr[group_id].ds_ptr_, + gemm_desc_ptr[group_id].e_ptr_, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_desc_ptr[group_id].a_grid_desc_k0_m0_m1_k1_, + gemm_desc_ptr[group_id].b_grid_desc_k0_n0_n1_k1_, + gemm_desc_ptr[group_id].ds_grid_desc_m0_m10_m11_n0_n10_n11_, + gemm_desc_ptr[group_id].e_grid_desc_m0_m10_m11_n0_n10_n11_, + gemm_desc_ptr[group_id].block_2_etile_map_, + integral_constant{}, + integral_constant{}); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif +} + +template && + is_same_v, + bool> = false> +struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm +{ + using DeviceOp = DeviceGroupedGemmMultipleD_Dl; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlMultipleD_km_kn_mn; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using DsGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})); + using EGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})); + + struct GroupedGemmBlock2ETileMap + { + using Block2ETileMap = + remove_cvref_t; + + GroupedGemmBlock2ETileMap() + { + block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}); + BlockStart_ = -1; + } + + GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart) + { + block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n); + BlockStart_ = BlockStart; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_2_etile_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] - BlockStart_)); + } + + // it's actually E-Tile + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const + { + return block_2_etile_map_.CheckValidity(e_grid_desc_m_n); + } + + Block2ETileMap block_2_etile_map_; + ck::index_t BlockStart_; + }; + + struct GemmKernelArg + { + // pointers + const ADataType* a_ptr_; + const BDataType* b_ptr_; + typename GridwiseGemm::DsGridPointer ds_ptr_; + EDataType* e_ptr_; + + // tensor descriptors for problem definiton + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_; + EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_; + + // block-to-e-tile map + GroupedGemmBlock2ETileMap block_2_etile_map_; + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + grid_size_ = 0; + + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! group_count_ != p_As/b/c.size"); + } + + gemm_desc_kernel_arg_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + a_mtx_mraw_kraw_.emplace_back(M, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideE = gemm_descs[i].stride_C_; + + typename GridwiseGemm::DsGridPointer p_ds_grid{}; + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + p_ds_grid(j) = static_cast(p_Ds[i][j]); + ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N( + M, N, gemm_descs[i].stride_Ds_[j]); + }); + + // tensor descriptors for problem definiton + const auto a_grid_desc_k0_m_k1 = + DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + const auto b_grid_desc_k0_n_k1 = + DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + const auto e_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideE); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n)) + { + + const index_t grid_size_grp = + GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0) + .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_k0_m0_m1_k1 = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1); + const auto b_grid_desc_k0_n0_n1_k1 = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1); + const auto ds_grid_desc_m0_m10_m11_n0_n10_n11 = + GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n); + const auto e_grid_desc_m0_m10_m11_n0_n10_n11 = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n); + + gemm_desc_kernel_arg_.push_back( + GemmKernelArg{static_cast(p_As[i]), + static_cast(p_Bs[i]), + p_ds_grid, + static_cast(p_Es[i]), + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_m_n, + e_grid_desc_m_n, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + ds_grid_desc_m0_m10_m11_n0_n10_n11, + e_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_etile_map, + BlockStart, + BlockEnd}); + } + } + } + + // private: + index_t group_count_; + index_t skipped_group_count_; + + // TODO: A,B element op is unused since gridwise_gemm_dl_v1r3 does NOT support prologue + // for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto K0 = arg.gemm_desc_kernel_arg_[0].a_grid_desc_k0_m_k1_.GetLength(I0); + bool all_has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + bool all_has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { +#if DEBUG_LOG + std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}" + << std::endl; + + std::cout << ", arg.b_grid_desc_k0_n_k1_{" + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}" + << std::endl; + + std::cout << ", arg.e_grid_desc_m_n_{ " + << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}" + << std::endl; +#endif + + if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_, + arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_, + arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting"); + } + + K0 = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + bool not_all_has_main_k_block_loop_same = + all_has_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(K0); + bool not_all_has_double_tail_k_block_loop_same = + all_has_double_tail_k_block_loop xor + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(not_all_has_main_k_block_loop_same or not_all_has_double_tail_k_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for [main|double_tail]_k_block_loop! in " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + hipGetErrorString(hipMemcpy(arg.p_workspace_, + arg.gemm_desc_kernel_arg_.data(), + arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg), + hipMemcpyHostToDevice)); + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop.value; + + const auto kernel = kernel_grouped_gemm_multiple_d_dl; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + }; + + if(all_has_main_k_block_loop && all_has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(all_has_main_k_block_loop && !all_has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!all_has_main_k_block_loop && all_has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if((ck::type_convert(arg.gemm_desc_kernel_arg_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { + return false; + } + + const std::string device_name = ck::get_device_name(); + + // TODO add newer Navi arch + if(device_name != "gfx906" and device_name != "gfx908" and device_name != "gfx90a" and + device_name != "gfx1030") + { + return false; + } + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_, + arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_, + arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_)) + { + return false; + } + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmMultipleD_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(GemmKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index ed079a3f15..e3795060be 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -382,6 +382,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; index_t grid_size_; }; @@ -600,7 +605,28 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + return supported; } // polymorphic @@ -662,7 +688,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 69fa75c3fd..136017c6d1 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" namespace ck { namespace tensor_operation { @@ -280,43 +281,42 @@ struct AddHardswish }; }; -// C = A * B // E = FastGelu(C + D) struct AddFastGelu { - // Fast GeLU - // https://paperswithcode.com/method/gelu - // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) - __host__ __device__ static constexpr float GetFastGeLU(float x) - { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - return x * cdf; - } - - template - static inline constexpr bool is_valid_param_type_v = - std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v; - template - __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const { - static_assert(is_valid_param_type_v && is_valid_param_type_v && - is_valid_param_type_v); + const float x = c + d; - const float y = GetFastGeLU(type_convert(c) + type_convert(d)); - - e = type_convert(y); + FastGelu{}.template operator()(e, x); } - template - __host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const { - static_assert(is_valid_param_type_v); + const half_t x = c + d; - e = GetFastGeLU(c + type_convert(d)); + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const float& c, const half_t& d) const + { + const float x0_f = c + d; + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); } }; diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 7f3d450a39..ceb2b665b9 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -16,7 +16,7 @@ namespace element_wise { // Need to ensure compiler will fail if there is no matching candidate, instead of compiler // siliently do implicit type conversion // -// Method 1: +// Example: // // struct ExampleElementwiseOp // { @@ -30,19 +30,6 @@ namespace element_wise { // { // } // }; -// -// Method 2: -// -// template -// struct ExampleElementwiseOp; -// -// template <> -// struct ExampleElementwiseOp -// { -// __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const -// { -// } -// }; struct AddReluAdd { @@ -208,41 +195,74 @@ struct AddMultiply } }; -// C = A * B // E = FastGelu(C + D0 + D1) struct AddAddFastGelu { - // Fast GeLU - // https://paperswithcode.com/method/gelu - // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) - __host__ __device__ static constexpr float GetFastGeLU(float x) - { - const float u = 2.f * x * (0.035677f * x * x + 0.797885f); - const float emu = exp(-u); - const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - return x * cdf; - } - - template - static inline constexpr bool is_valid_param_type_v = - std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || std::is_same_v -#endif - ; - template __host__ __device__ constexpr void - operator()(E& e, const C& c, const D0& d0, const D1& d1) const + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(float& e, + const float& c, + const float& d0, + const float& d1) const { - static_assert(is_valid_param_type_v && is_valid_param_type_v && - is_valid_param_type_v && is_valid_param_type_v); + const float x = c + d0 + d1; - const float y = - GetFastGeLU(type_convert(c) + type_convert(d0) + type_convert(d1)); + FastGelu{}.template operator()(e, x); + } - e = type_convert(y); + template <> + __host__ __device__ constexpr void operator()( + half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const + { + const half_t x = c + d0 + d1; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& e, const float& c, const half_t& d0, const half_t& d1) const + { + const float x0_f = c + d0 + d1; + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void operator()( + bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const + { + const float x0_f = c + type_convert(d0) + type_convert(d1); + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void operator()( + int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const + { + const float x0_f = + type_convert(c) + type_convert(d0) + type_convert(d1); + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); } }; diff --git a/include/ck/tensor_operation/gpu/element/quantization_operation.hpp b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp index 3f2c2f8773..fefa6c793f 100644 --- a/include/ck/tensor_operation/gpu/element/quantization_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/quantization_operation.hpp @@ -1,15 +1,36 @@ #pragma once #include "ck/utility/data_type.hpp" +// #include "ck/utility/get_id.hpp" namespace ck { namespace tensor_operation { namespace element_wise { +// Y = Sy * Qy +// W = Sw * Qw +// X = Sx * Qx +// B = Sb * Qb = Sw * Sx * Qb +// Where X, W, Y are float32, Qx, Qw, Qy are int8 +// Sx, Sw, Sy are scale of x, w, y (float32), which is calculated from quantization range +// Qb is int32, scale of B is Sw * Sx for convenient + +// Y = W @ X, where @ is convolution or matrix multiplication +// Sy * Qy = Sw * Qw @ Sx * Qx +// Qy = [(Sw*Sx)/Sy] * Qw @ Qx + // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +// Activation(Sy * Qy) = Sy * Activation(Qy) template struct Activation_Mul_Clamp { + // Convolution + Activation (piecewise linear function) + // If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy) + // Z = Activation(Y) = Activation(W @ X) + // Sz * Qz = Activation(Sy * Qy) + // Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx) + + // requantScale_ = Sw * Sx / Sz Activation_Mul_Clamp(float requantScale, Activation activationOp) : requantScale_(requantScale), activationOp_(activationOp) { @@ -17,26 +38,66 @@ struct Activation_Mul_Clamp __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const { - float x_fp32 = ck::type_convert(x); - activationOp_(x_fp32, x_fp32); - float y_fp32 = math::clamp(requantScale_ * x_fp32, -128.f, 127.f); - y = ck::type_convert(y_fp32); + float y_fp32 = ck::type_convert(x); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); } - __host__ __device__ constexpr void operator()(float& y, const int32_t& x) const + __device__ constexpr void operator()(int32_t& y, const int32_t& x) const { - // We might type_convert to int8 after lambda in someplace - float x_fp32 = ck::type_convert(x); - activationOp_(x_fp32, x_fp32); - y = math::clamp(requantScale_ * x_fp32, -128.f, 127.f); + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ constexpr void operator()(float& y, const float& x) const + { + // CAUSION - We might float in & float out in reference code + activationOp_(y, x); + y = math::clamp(requantScale_ * y, -128.f, 127.f); } float requantScale_; Activation activationOp_; }; +// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc +// If an activation is not piecewise linear function +// then Activation(Sy * Qy) != Sy * Activation(Qy) +template +struct Mul_Activation_Mul_Clamp +{ + // Convolution + Activation (non piecewise linear function) + // Z = Activation(Y) = Activation(W @ X) + // Sz * Qz = Activation(Sy * Qy) + // Qz = S1 * Activation[Sacc * (Qw @ Qx)] + // Where S1 = 1 / Sz, Sacc = Sw * Sx + Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp) + : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const + { + float y_fp32 = ck::type_convert(x); + y_fp32 = scaleAcc_ * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float scale_z_inv_; + float scaleAcc_; + Activation activationOp_; +}; + // Conv Perchannel quantization + Activation function which is piecewise linear function, such as // relu, leaky relu ...etc +// Activation(Sy * Qy) = Sy * Activation(Qy) template struct Activation_Mul2_Clamp { @@ -51,13 +112,35 @@ struct Activation_Mul2_Clamp y = ck::type_convert(y_fp32); } + __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const float& requantScale) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + Activation activationOp_; }; // For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +// Activation(Sy * Qy) = Sy * Activation(Qy) template struct Add_Activation_Mul_Clamp { + // Convolution + bias + // Let Bias = B = Sw * Sx * Qb + // Where Qb is int32 + // Y = W @ X + B + // Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb + // Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb) + + // For activation, Z = Activaiton(Y) + // Sz * Qz = Activation(Sy * Qy) + // Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb) Add_Activation_Mul_Clamp(float requantScale, Activation activationOp) : requantScale_(requantScale), activationOp_(activationOp) { @@ -72,6 +155,17 @@ struct Add_Activation_Mul_Clamp y = ck::type_convert(y_fp32); } + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + float requantScale_; Activation activationOp_; }; @@ -92,15 +186,33 @@ struct Add_Activation_Mul2_Clamp y = ck::type_convert(y_fp32); } + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + Activation activationOp_; }; // For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc +// If an activation is not piecewise linear function +// then Activation(Sy * Qy) != Sy * Activation(Qy) template struct Add_Mul_Activation_Mul_Clamp { - Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp) - : requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp) + // Convolution + Activation (non piecewise linear function) + // Z = Activation(Y) = Activation(W @ X + B) + // Sz * Qz = Activation(Sy * Qy) + // Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)] + // Where S1 = 1 / Sz, Sacc = Sw * Sx + Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp) + : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp) { } @@ -108,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp operator()(int8_t& y, const int32_t& x, const int32_t& bias) const { float y_fp32 = ck::type_convert(x + bias); - y_fp32 = requantScale1_ * y_fp32; + y_fp32 = scaleAcc_ * y_fp32; activationOp_(y_fp32, y_fp32); - y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); y = ck::type_convert(y_fp32); } - float requantScale1_; - float requantScale2_; + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc_ * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float scale_z_inv_; + float scaleAcc_; + Activation activationOp_; +}; + +// Conv Perchannel quantization + Activation function which is non piecewise linear function, +// such as TanH, Sigmoid ...etc +// If an activation is not piecewise linear function +// then Activation(Sy *Qy) != Sy * Activation(Qy) +template +struct Add_Mul2_Activation_Mul_Clamp +{ + Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp) + : scale_z_inv_(scale_z_inv), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const + { + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float scale_z_inv_; Activation activationOp_; }; diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 2167a79e01..f1f3042ad1 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -11,6 +11,10 @@ namespace ck { namespace tensor_operation { namespace element_wise { +#if CK_WORKAROUND_SWDEV_383542 +extern "C" __device__ float __ocml_native_recip_f32(float); +#endif + struct PassThrough { template @@ -200,36 +204,83 @@ struct Relu } }; -// Y = FastGelu(X) +// Fast GeLU +// https://paperswithcode.com/method/gelu +// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) +// host code use higher accuracy "exp" and "div" +// gpu code use lower accuracy "__expf" and "rcp" function struct FastGelu { - // Fast GeLU - // https://paperswithcode.com/method/gelu - // y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) - __host__ __device__ static constexpr float GetFastGeLU(float x) + template + __host__ void operator()(Y& y, const X& x) const; + + template + __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ void operator()(float& y, const float& x) const { const float u = 2.f * x * (0.035677f * x * x + 0.797885f); const float emu = exp(-u); const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f); - return x * cdf; + + y = x * cdf; } - template - static inline constexpr bool is_valid_param_type_v = - std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - || std::is_same_v -#endif - ; - - template - __host__ __device__ void operator()(Y& y, const X& x) const + // device code, use lower precision "__expf" and "rcp" + template <> + __device__ void operator()(float& y, const float& x) const { - static_assert(is_valid_param_type_v && is_valid_param_type_v); + const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float emu = __expf(-u); - const float tmp_y = GetFastGeLU(type_convert(x)); - y = type_convert(tmp_y); +#if !CK_WORKAROUND_SWDEV_383542 + const float cdf = 0.5f + 0.5f * (2.f * __frcp_rn(1.f + emu) - 1.f); +#else + const float cdf = 0.5f + 0.5f * (2.f * __ocml_native_recip_f32(1.f + emu) - 1.f); +#endif + + y = x * cdf; + } + + template <> + __host__ void operator()(half_t& y, const half_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(half_t& y, const half_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + __host__ void operator()(half_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(half_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); } }; @@ -269,6 +320,19 @@ struct Sigmoid int32_t divider_ = 1; }; +struct TanH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::tanh(x); + }; +}; + } // namespace element_wise } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index ce39c4967b..d6d2051113 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -18,6 +18,10 @@ namespace ck { +/** + * @brief Gridwise gemm + softmax + gemm fusion + * + */ template __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n) + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n) { const auto M = c_grid_desc_m_n.GetLength(I0); const auto N = c_grid_desc_m_n.GetLength(I1); @@ -238,19 +240,19 @@ struct GridwiseGemmDlMultipleD_km_kn_mn using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using CGridDesc_M0_M10_M11_N0_N10_N11 = decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); - using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); using DsGridPointer = decltype(MakeDsGridPointer()); template + bool HasDoubleTailKBlockLoop, + typename Block2CTileMap> __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, DsGridPointer p_ds_grid, FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, + void* __restrict__ p_shared_block, const AElementwiseOperation&, const BElementwiseOperation&, const CDEElementwiseOperation& cde_element_op, @@ -399,8 +401,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn constexpr auto b_block_aligned_space_size = math::integer_least_multiple( b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + FloatAB* p_a_block_double = static_cast(p_shared_block); + FloatAB* p_b_block_double = + static_cast(p_shared_block) + 2 * a_block_aligned_space_size; // register allocation for output auto c_thread_buf = make_static_buffer( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 2ce4d8feb3..38edace197 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -54,7 +54,8 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -147,7 +148,8 @@ __global__ void const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__)) // printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; @@ -242,7 +244,8 @@ __global__ void const CDEElementwiseOperation cde_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__)) __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; GridwiseOp::template Run(p_a_grid, @@ -271,7 +274,7 @@ __global__ void ignore = b_element_op; ignore = cde_element_op; ignore = block_2_ctile_map; -#endif // end of if (defined(__gfx1100__)) +#endif // end of if (defined(__gfx1100__ )) } template < // DataType Family @@ -428,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + constexpr auto max_lds_align = K1; constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -436,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle constexpr auto b_block_space_size_aligned = math::integer_least_multiple( b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); - return (a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)); + constexpr auto c_block_space_size_aligned = math::integer_least_multiple( + cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(), + max_lds_align); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_space_size_aligned * sizeof(CShuffleDataType)); } // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} @@ -673,7 +684,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO())>; + // denorm test fix, required to work around fp16 mfma issue + // we convert fp16->fp32->bf16 and execute bf16 mfma instruction + // when mfma if fixed, remove this section and update + // ABDataTypeAdjusted -> ABDataType throughout this file +#if defined(__gfx90a__) + using ABDataTypeAdjusted = + conditional_t, ck::bhalf_t, ABDataType>; +#else + using ABDataTypeAdjusted = ABDataType; +#endif + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy @@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABDataType, - ABDataType, + ABDataTypeAdjusted, decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, ABDataType, - ABDataType, + ABDataTypeAdjusted, decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // sanity check constexpr index_t KPack = math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ABDataType, + ABDataTypeAdjusted, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index e9097552c2..d1209636de 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index d70c5180da..1fee302c3c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -49,7 +49,8 @@ __global__ void const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -414,7 +415,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO(p_a_grid, p_b_grid, p_c_grid, - p_shared_block, + p_shared, a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_grid_desc_mblock_mperblock_nblock_nperblock, @@ -184,16 +181,16 @@ __global__ void c_element_op, c_block_cluster_adaptor); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -265,6 +262,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight using GridwiseGemmPipe = remove_cvref_t())>; + // denorm test fix, required to work around fp16 mfma issue + // we convert fp16->fp32->bf16 and execute bf16 mfma instruction + // when mfma if fixed, remove this section and update + // FloatABAdjusted -> FloatAB throughout this file +#if defined(__gfx90a__) + using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; +#else + using FloatABAdjusted = FloatAB; +#endif + // M0/M1/M1Padding static constexpr auto M1PerBlock = Number{}; static constexpr auto M0PerBlock = Number{}; @@ -606,7 +613,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, + void* __restrict__ p_shared, const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& @@ -667,7 +674,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, FloatAB, - FloatAB, + FloatABAdjusted, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -697,7 +704,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, FloatAB, - FloatAB, + FloatABAdjusted, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -726,11 +733,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight // sanity check constexpr index_t KPack = - math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + math::max(K1, MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( - p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + static_cast(p_shared) + a_block_space_size, + b_k0_n_k1_block_desc.GetElementSpaceSize()); // gridwise GEMM pipeline const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); @@ -799,8 +805,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); - void* p_shared = static_cast(p_shared_block); - auto c_block_buf = make_dynamic_buffer( static_cast(p_shared), c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 73d7088bc8..df6dc68aff 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -59,16 +59,16 @@ __global__ void c_element_op, block_2_ctile_map); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; - ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = block_2_ctile_map; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -132,6 +132,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using GridwiseGemmPipe = remove_cvref_t())>; + // denorm test fix, required to work around fp16 mfma issue + // we convert fp16->fp32->bf16 and execute bf16 mfma instruction + // when mfma if fixed, remove this section and update + // FloatABAdjusted -> FloatAB throughout this file +#if defined(__gfx90a__) + using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; +#else + using FloatABAdjusted = FloatAB; +#endif + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() { constexpr auto max_lds_align = K1; @@ -282,7 +292,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1( - static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0_n_k1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index f0ce2e3bdb..03b315200d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -18,61 +18,24 @@ namespace ck { template + bool HasMainKBlockLoop, + InMemoryDataOperationEnum CGlobalMemoryDataOperation> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r4r2(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx940__)) - constexpr index_t shared_block_size = - GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); - __shared__ FloatAB p_shared_block[shared_block_size]; + __shared__ uint8_t p_shared[shared_size]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_c_grid, - static_cast(p_shared_block), - a_b_k0_m_k1_grid_desc, - b_b_k0_n_k1_grid_desc, - c_grid_desc_mblock_mperblock_nblock_nperblock, - a_element_op, - b_element_op, - c_element_op, - c_block_cluster_adaptor); + GridwiseGemm::template Run( + karg, static_cast(p_shared)); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -80,13 +43,13 @@ template {}; // K1 should be Number<...> - static constexpr auto K1 = Number{}; + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; using ThisThreadBlock = ThisThreadBlock; + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t K0; + index_t k_batch; + + Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0_, + index_t k_batch_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + MPadded(MPadded_), + NPadded(NPadded_), + KPadded(KPadded_), + K0(K0_), + k_batch(k_batch_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "K0:" << K0 << ", " + << "KB:" << k_batch << "}" << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + // prefer this to be called on host + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return (M + MPerBlock - 1) / MPerBlock * MPerBlock; + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return (N + NPerBlock - 1) / NPerBlock * NPerBlock; + } + + __host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1) + { + // k_batch * k0 * k0_per_block * k1 + auto K_t = K_Batch * K0PerBlock * K1; + return (K + K_t - 1) / K_t * K0PerBlock; + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K0 = CalculateK0(K, K_Batch); + return K_Batch * K0 * K1; + } + + __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, + index_t MPad, + index_t K, + index_t StrideA, + index_t KBatch, + index_t K0, + index_t KPad) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, + index_t NPad, + index_t N, + index_t StrideB, + index_t KBatch, + index_t K0, + index_t KPad) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t N, index_t MPad, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor(c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = K1; @@ -179,45 +370,68 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 c_block_size * sizeof(FloatC)); } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) { - static_assert(is_known_at_compile_time>::value, - "wrong! K1 need to be known at compile-time"); - - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && - (NPerBlock % (NRepeat * NPerXDL)) == 0, - "Invalid tuning param!"); - - const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); - const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); - const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); - const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); - - if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && - K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && - K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && - KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) - return false; - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) - return false; - - if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) { - return false; + if(!(karg.M % MPerBlock == 0)) + return false; + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + return false; + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + return false; + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + return false; + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + return false; + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + return false; + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + return false; + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + return false; } - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } + __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0 * K1; + return KPad; + } + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) { const bool has_main_k0_block_loop = K0 > K0PerBlock; @@ -225,8 +439,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 return has_main_k0_block_loop; } + template __host__ __device__ static constexpr auto - MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) { const auto M = c_m_n_grid_desc.GetLength(I0); const auto N = c_m_n_grid_desc.GetLength(I1); @@ -243,10 +458,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 } // return block_id to C matrix tile idx (m0, n0) mapping + template __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( - const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) { - return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( c_m_n_grid_desc, 8, KBatch); } @@ -263,24 +479,25 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 Number{})); } - using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); - - template - __device__ static void Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - void* __restrict__ p_shared_block, - const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op, - const CBlockClusterAdaptor& c_block_cluster_adaptor) + template + __device__ static void Run(const Argument& karg, void* __restrict__ p_shared_block) { + const FloatAB* p_a_grid = karg.p_a_grid; + const FloatAB* p_b_grid = karg.p_b_grid; + FloatC* p_c_grid = karg.p_c_grid; + const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded); + const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded); + const auto c_grid_desc_m_n = + MakeCGridDescriptor_M_N(karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const AElementwiseOperation a_element_op = AElementwiseOperation{}; + const BElementwiseOperation b_element_op = BElementwiseOperation{}; + const CElementwiseOperation c_element_op = CElementwiseOperation{}; + const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -290,26 +507,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); - // divide block work by [M, N] - const auto block_work_idx = - c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!c_block_cluster_adaptor.ValidCTileIndex( - make_tuple(block_work_idx[I1], block_work_idx[I2]), - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) - { - return; - } - - const index_t k_batch_id = block_work_idx[I0]; + const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t k_batch_id = __builtin_amdgcn_readfirstlane(blockIdx.z); // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); // lds max alignment constexpr auto max_lds_align = K1; @@ -445,7 +652,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + struct LStr + { + static std::string Get() { return ""; } + }; + + template <> + struct LStr + { + static std::string Get() { return "R"; } + }; + + template <> + struct LStr + { + static std::string Get() { return "C"; } + }; + + static std::string GetTypeString() + { + auto str = std::stringstream(); + + // clang-format off + str << "GemmXdlSplitKCShuffle_" + << getGemmSpecializationString(GemmSpec) << "_" + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << "_" + << "B" << BlockSize << "_" + << "Vec" << ABlockTransferSrcScalarPerVector << "x" + << BBlockTransferSrcScalarPerVector << "x" + << CBlockTransferScalarPerVector_NWaveNPerXDL << "_" + << MPerBlock << "x" + << NPerBlock << "x" + << K0PerBlock << "x" + << K1 ; + // clang-format on + + return str.str(); + } }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index b0f453b025..16484ddcc5 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1201,7 +1201,12 @@ struct ThreadwiseTensorSliceTransfer_v4 SrcCoord src_ref_coord_; }; -// Do NOT involve any tensor coordinates with StaticBuffer +/** + * @brief Threadwise data transfer + * + * Do NOT involve any tensor coordinates with StaticBuffer + * + */ template ::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x7fffffff; + uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); @@ -1091,7 +1091,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t constexpr index_t vector_size = scalar_type::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); @@ -1126,7 +1126,7 @@ amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thr constexpr index_t vector_size = scalar_type::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; amd_buffer_atomic_add_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); @@ -1161,7 +1161,7 @@ amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thr constexpr index_t vector_size = scalar_type::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; amd_buffer_atomic_max_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 4fc0be1fbd..1f7df70bcd 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0 "0"(c0), "1"(c1)); #else - c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); #endif } @@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, "2"(c2), "3"(c3)); #else - c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); - c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); - c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); #endif } @@ -358,7 +358,13 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, // Ranged input operand __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c) { +#if defined(__gfx11__) asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +#else + ignore = a; + ignore = b; + ignore = c; +#endif } } // namespace ck diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index a0e79220e0..bf09142548 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -23,11 +23,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> { // * Inline assembly need to elimate the duplicated data load, compiler won't help you // delete them. - amd_assembly_wmma_f32_16x16x16_f16_w32( - reg_a, reg_b, reg_c.template AsType()(Number<0>{})); - // reg_c.template AsType()(Number<0>{}) = - // __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( reg_a, reg_b, reg_c.template - // AsType()[Number<0>{}]); + // amd_assembly_wmma_f32_16x16x16_f16_w32( + // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -41,9 +46,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -60,8 +71,14 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -78,9 +95,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -94,6 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( neg_a, @@ -102,6 +126,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -116,8 +145,14 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> template __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -131,9 +166,15 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> template __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -150,8 +191,14 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -168,9 +215,15 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> // opsel usage // false: D0.[0:15] = result // true : D0.[16:31]= result +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; @@ -184,6 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> template __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) { +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( neg_a, @@ -192,6 +246,11 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> bit_cast(reg_b), reg_c.template AsType()[Number<0>{}], clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif } }; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index e97b932ab4..101061191e 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -976,42 +976,94 @@ inline __host__ __device__ constexpr bhalf_t type_convert(float uint32_t int32; } u = {x}; - if(~u.int32 & 0x7f800000) - { - // When the exponent bits are not all 1s, then the value is zero, normal, - // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus - // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). - // This causes the bfloat16's mantissa to be incremented by 1 if the 16 - // least significant bits of the float mantissa are greater than 0x8000, - // or if they are equal to 0x8000 and the least significant bit of the - // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when - // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already - // has the value 0x7f, then incrementing it causes it to become 0x00 and - // the exponent is incremented by one, which is the next higher FP value - // to the unrounded bfloat16 value. When the bfloat16 value is subnormal - // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up - // to a normal value with an exponent of 0x01 and a mantissa of 0x00. - // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, - // incrementing it causes it to become an exponent of 0xFF and a mantissa - // of 0x00, which is Inf, the next higher value to the unrounded value. - u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even - } - else if(u.int32 & 0xffff) - { - // When all of the exponent bits are 1, the value is Inf or NaN. - // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero - // mantissa bit. Quiet NaN is indicated by the most significant mantissa - // bit being 1. Signaling NaN is indicated by the most significant - // mantissa bit being 0 but some other bit(s) being 1. If any of the - // lower 16 bits of the mantissa are 1, we set the least significant bit - // of the bfloat16 mantissa, in order to preserve signaling NaN in case - // the bloat16's mantissa bits are all 0. - u.int32 |= 0x10000; // Preserve signaling NaN - } + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + bool flag0 = ~u.int32 & 0x7f800000; + + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bfloat16's mantissa bits are all 0. + bool flag1 = !flag0 && (u.int32 & 0xffff); + + u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even + u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN return uint16_t(u.int32 >> 16); } +// convert bfp16 to fp16 via fp32 +template <> +inline __host__ __device__ constexpr half_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert fp16 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(half_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert bfp16 to int32 via fp32 +template <> +inline __host__ __device__ constexpr int32_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert int32 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(int32_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + +// convert bfp16 to int8 via fp32 +template <> +inline __host__ __device__ constexpr int8_t type_convert(bhalf_t x) +{ + float x_fp32 = type_convert(x); + + return static_cast(x_fp32); +} + +// convert int8 to bfp16 via fp32 +template <> +inline __host__ __device__ constexpr bhalf_t type_convert(int8_t x) +{ + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + template struct NumericLimits { diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index 0f45ec177a..b65640bfff 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -135,6 +135,28 @@ __device__ void inner_product(const half8_t& a, const h c); } +template <> +__device__ void inner_product(const int8_t& a, const int8_t& b, int32_t& c) +{ + c += type_convert(a) * type_convert(b); +} + +template <> +__device__ void +inner_product(const int8x2_t& a, const int8x2_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + template <> __device__ void inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 4febace0b8..a3732b2fe0 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -92,6 +92,15 @@ static inline __host__ float sqrt(float x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); }; +static inline __host__ half_t tanh(half_t x) +{ + return static_cast(std::tanh(static_cast(x))); +}; + +static inline __host__ float tanh(float x) { return std::tanh(x); }; + +static inline __host__ double tanh(double x) { return std::tanh(x); }; + // math functions for the HIP kernel, some are implemented by calling hip builtin functions static inline __device__ float abs(float x) { return ::abs(x); }; @@ -172,5 +181,14 @@ static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; +static inline __device__ half_t tanh(half_t x) +{ + return static_cast(::tanhf(static_cast(x))); +}; + +static inline __device__ float tanh(float x) { return ::tanhf(x); }; + +static inline __device__ double tanh(double x) { return ::tanh(x); }; + } // namespace math } // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 6210637ad3..f176cb91e0 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -26,6 +26,7 @@ using Empty_Tuple = ck::Tuple<>; using F16_Tuple = ck::Tuple; using F16_F16_Tuple = ck::Tuple; +using F64_Tuple = ck::Tuple; using F32_Tuple = ck::Tuple; using I32_Tuple = ck::Tuple; using I32_F32_Tuple = ck::Tuple; @@ -85,6 +86,7 @@ using GK_GK_Tuple = ck::Tuple; // pointwise functor using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Relu = ck::tensor_operation::element_wise::Relu; +using TanH = ck::tensor_operation::element_wise::TanH; using Scale = ck::tensor_operation::element_wise::Scale; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; @@ -93,6 +95,7 @@ using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; using FastGelu = ck::tensor_operation::element_wise::FastGelu; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; +using Gelu = ck::tensor_operation::element_wise::Gelu; template using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; @@ -101,6 +104,10 @@ template using Add_Activation_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +template +using Add_Mul_Activation_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; + template using Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; @@ -108,6 +115,10 @@ template using Add_Activation_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; +template +using Add_Mul2_Activation_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp; + template struct DeviceOperationInstanceFactory; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp index a0cea7e390..c116d999da 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -19,6 +19,7 @@ namespace tensor_operation { namespace device { namespace instance { +// float void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector>>& instances); +// double +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + std::vector>>& instances); + // Contraction + Bilinear template && is_same_v && + is_same_v && is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + op_ptrs); + } + } + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp index e921ecd47a..e3f07606c2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp @@ -19,6 +19,7 @@ namespace tensor_operation { namespace device { namespace instance { +// float void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( std::vector>>& instances); +// double +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( + std::vector>>& instances); + // Contraction + Scale template && is_same_v && + is_same_v) + { + if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) + { + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( + op_ptrs); + } + } + return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index c64598dadd..fa7462dbd9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -74,18 +74,17 @@ template -struct DeviceOperationInstanceFactory> +struct DeviceOperationInstanceFactory> { using DeviceOp = DeviceGroupedGemm; + PassThrough, + PassThrough, + PassThrough>; static auto GetInstances() { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp new file mode 100644 index 0000000000..eb181ebc7f --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances); + +// GroupedGEMM + GELU +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = DeviceGroupedGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp new file mode 100644 index 0000000000..2fd7ce22f7 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Layout(A, B, C) = [Col, Row, Row] +void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>>& + instances); + +// Layout(A, B, C) = [Col, Col, Row] +void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>>& + instances); + +// Layout(A, B, C) = [Row, Row, Row] +void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>>& + instances); + +// Layout(A, B, C) = [Row, Col, Row] +void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>>& + instances); + +// Layout(A, B, C) = [Col, Row, Row] +void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>>& + instances); + +// Layout(A, B, C) = [Col, Col, Row] +void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>>& + instances); + +// Layout(A, B, C) = [Row, Row, Row] +void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>>& + instances); + +// Layout(A, B, C) = [Row, Col, Row] +void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>>& + instances); + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGemmMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); + add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); + add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; + } + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp index eda81a233c..793dc8d04a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp @@ -18,7 +18,7 @@ namespace device { namespace instance { // grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_conv2d_bias_perchannel_quantization_int8_instances( +void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( std::vector< std::unique_ptr>>>& instances); -void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances( +void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances( std::vector>>>& instances); +void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + +void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + +void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + +// piecewise activation function template && is_same_v) { if constexpr(is_same_v) - add_device_conv2d_bias_perchannel_quantization_int8_instances(op_ptrs); + { + add_device_conv2d_dl_bias_perchannel_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances(op_ptrs); + } else if constexpr(is_same_v) - add_device_conv2d_bias_relu_perchannel_quantization_int8_instances(op_ptrs); + { + add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances(op_ptrs); + } + } + } + + return op_ptrs; + } +}; + +// non-piecewise activation function +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v) + { + add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances(op_ptrs); + } } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp index 1138402638..c570f76750 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp @@ -18,7 +18,7 @@ namespace device { namespace instance { // grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_conv2d_bias_perlayer_quantization_int8_instances( +void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances( std::vector< std::unique_ptr>>>& instances); -void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( +void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances( std::vector>>>& instances); +void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( + std::vector< + std::unique_ptr>>>& + instances); + +void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +// piecewise activation function template && is_same_v) { if constexpr(is_same_v) - add_device_conv2d_bias_perlayer_quantization_int8_instances(op_ptrs); + { + add_device_conv2d_dl_bias_perlayer_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances(op_ptrs); + } else if constexpr(is_same_v) - add_device_conv2d_bias_relu_perlayer_quantization_int8_instances(op_ptrs); + { + add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances(op_ptrs); + } + } + } + + return op_ptrs; + } +}; + +// non-piecewise activation function +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v) + { + add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances(op_ptrs); + } } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp index 1a67ce5688..089343fe6f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp @@ -18,7 +18,7 @@ namespace device { namespace instance { // grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_conv2d_perchannel_quantization_int8_instances( +void add_device_conv2d_dl_perchannel_quantization_int8_instances( std::vector>>>& instances); -void add_device_conv2d_relu_perchannel_quantization_int8_instances( +void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_xdl_perchannel_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances( std::vector) { if constexpr(is_same_v) - add_device_conv2d_perchannel_quantization_int8_instances(op_ptrs); + { + add_device_conv2d_dl_perchannel_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_perchannel_quantization_int8_instances(op_ptrs); + } else if constexpr(is_same_v) - add_device_conv2d_relu_perchannel_quantization_int8_instances(op_ptrs); + { + add_device_conv2d_dl_relu_perchannel_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances(op_ptrs); + } } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp index 410be4a57b..e570027eb5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp @@ -18,7 +18,7 @@ namespace device { namespace instance { // grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_conv2d_perlayer_quantization_int8_instances( +void add_device_conv2d_dl_perlayer_quantization_int8_instances( std::vector>>>& instances); -void add_device_conv2d_relu_perlayer_quantization_int8_instances( +void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_xdl_perlayer_quantization_int8_instances( + std::vector>>>& + instances); + +void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances( std::vector) { if constexpr(is_same_v) - add_device_conv2d_perlayer_quantization_int8_instances(op_ptrs); + { + add_device_conv2d_dl_perlayer_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_perlayer_quantization_int8_instances(op_ptrs); + } else if constexpr(is_same_v) - add_device_conv2d_relu_perlayer_quantization_int8_instances(op_ptrs); + { + add_device_conv2d_dl_relu_perlayer_quantization_int8_instances(op_ptrs); + add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(op_ptrs); + } } } diff --git a/library/include/ck/library/utility/device_memory.hpp b/library/include/ck/library/utility/device_memory.hpp index 3c4ece4406..87940e1671 100644 --- a/library/include/ck/library/utility/device_memory.hpp +++ b/library/include/ck/library/utility/device_memory.hpp @@ -14,6 +14,10 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) } } +/** + * @brief Container for storing data in GPU device memory + * + */ struct DeviceMem { DeviceMem() = delete; diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index 54d58f362c..c0bc372764 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -100,6 +100,15 @@ struct FillMonotonicSeq return tmp; }); } + + template + auto operator()(ForwardRange&& range) const -> std::void_t()(std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } }; template @@ -112,6 +121,15 @@ struct FillConstant { std::fill(first, last, value_); } + + template + auto operator()(ForwardRange&& range) const -> std::void_t< + decltype(std::declval()(std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } }; } // namespace utils diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 00b37d52b4..fd09471958 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -47,7 +47,9 @@ using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16 // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + #if CK_WORKAROUND_SWDEV_388832 DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + #endif DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp index 21da6895e6..291a127a66 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -47,7 +47,9 @@ using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_ // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + #if CK_WORKAROUND_SWDEV_388832 DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + #endif DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, F32, F16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index ffd6a6a7be..d2a0a3d0fb 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,7 +1,13 @@ add_instance_library(device_contraction_bilinear_instance + #float device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp + #double + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp new file mode 100644 index 0000000000..093b2f0e98 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp new file mode 100644 index 0000000000..0f683e5c28 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp new file mode 100644 index 0000000000..e384993aed --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp new file mode 100644 index 0000000000..92e39c173f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index 7ad6605486..31f6a0fcdc 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,7 +1,13 @@ add_instance_library(device_contraction_scale_instance + #float device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp + #double + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp new file mode 100644 index 0000000000..0aa927155a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp new file mode 100644 index 0000000000..b84ea274c5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp new file mode 100644 index 0000000000..578469997a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp new file mode 100644 index 0000000000..8e5a19313e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; + +void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp index 9b5ff40484..c4680db831 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -26,7 +26,8 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< @@ -35,14 +36,22 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances = std::tuple< //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 192, 64, 4, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 192, 4, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 192, 32, 4, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 64, 32, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt new file mode 100644 index 0000000000..648f2146cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_grouped_gemm_fastgelu_instance + device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp + device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..c2f5f00c7a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[k, m] * b[k, n] = e[m, n] +using device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..476d4ce1f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[k, m] * b[n, k] = e[m, n] +using device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Col, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..1023fa4810 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[m, k] * b[k, n] = e[m, n] +using device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..6b065c0f82 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// a[m, k] * b[n, k] = e[m, n] +using device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //###################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, FastGelu, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instances{}); + add_device_operation_instances( + instances, + device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index 9f826afd68..db650dbfbd 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -1,6 +1,38 @@ -add_instance_library(device_quantization_instance - device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp - device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp - device_conv2d_xdl_perchannel_quantization_int8_instance.cpp - device_conv2d_xdl_perlayer_quantization_int8_instance.cpp +set(CONV2D_PERLAYER_QUANT_SRC + conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp + conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp +) + +set(CONV2D_PERCHANNEL_QUANT_SRC + conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp + conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp +) + +set(CONV2D_BIAS_PERLAYER_QUANT_SRC + conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp + conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp +) + +set(CONV2D_BIAS_PERCHANNEL_QUANT_SRC + conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp + conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp +) + +set(GEMM_QUANT_SRC + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp + gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp + gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp + gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp + gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp +) + +add_instance_library(device_quantization_instance + ${CONV2D_PERLAYER_QUANT_SRC} + ${CONV2D_PERCHANNEL_QUANT_SRC} + ${CONV2D_BIAS_PERLAYER_QUANT_SRC} + ${CONV2D_BIAS_PERCHANNEL_QUANT_SRC} + ${GEMM_QUANT_SRC} ) diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp new file mode 100644 index 0000000000..b231f8c956 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Empty_Tuple = ck::Tuple<>; +template +using S = ck::Sequence; + +using GNHWC = ck::tensor_layout::convolution::GNHWC; +using GKYXC = ck::tensor_layout::convolution::GKYXC; +using GNHWK = ck::tensor_layout::convolution::GNHWK; +using GK = ck::tensor_layout::convolution::G_K; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Relu = ck::tensor_operation::element_wise::Relu; +using TanH = ck::tensor_operation::element_wise::TanH; + +using GK_Tuple = ck::Tuple; +using GK_GK_Tuple = ck::Tuple; +using I32_Tuple = ck::Tuple; +using F32_Tuple = ck::Tuple; +using I32_F32_Tuple = ck::Tuple; + +// perlayer +using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; +using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +// bias + perlayer +using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +using Add_Mul_TanH_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; + +// perchannel +using Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; +using Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; + +// bias + perchannel +using Add_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; +using Add_Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; +using Add_Mul2_TanH_Mul_Clamp = + ck::tensor_operation::element_wise::Add_Mul2_Activation_Mul_Clamp; + +static constexpr ck::index_t NDimSpatial = 2; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp new file mode 100644 index 0000000000..ae5c1d7c32 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp @@ -0,0 +1,118 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_dl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + // dl + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + +void add_device_conv2d_dl_bias_relu_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + // dl + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + +void add_device_conv2d_dl_bias_tanh_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + // dl + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp new file mode 100644 index 0000000000..192d5c9a55 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_dl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + +void add_device_conv2d_dl_bias_relu_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + +void add_device_conv2d_dl_bias_tanh_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp new file mode 100644 index 0000000000..3c4987f159 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_int8_instance.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "conv2d_quantization_common.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +template +using device_grouped_conv2d_dl_int8_instances = + std::tuple< + // ###########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ###########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ###########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ###########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< NDimSpatial, int8_t, int8_t, DsDatatype, int8_t, int32_t, GNHWC, GKYXC, DsLayout, GNHWK, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, DstScalarPerVector> + >; +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp new file mode 100644 index 0000000000..d45c1c1ee7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_dl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_dl_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + +void add_device_conv2d_dl_relu_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp new file mode 100644 index 0000000000..f4b9479502 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_dl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_dl_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} + +void add_device_conv2d_dl_relu_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_dl_int8_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp new file mode 100644 index 0000000000..b6e8ee1590 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_xdl_bias_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} + +void add_device_conv2d_xdl_bias_relu_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} + +void add_device_conv2d_xdl_bias_tanh_perchannel_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp new file mode 100644 index 0000000000..70f92cec3a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp @@ -0,0 +1,119 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_xdl_bias_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} + +void add_device_conv2d_xdl_bias_relu_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} + +void add_device_conv2d_xdl_bias_tanh_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp new file mode 100644 index 0000000000..262ec06b7f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_int8_instance.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "conv2d_quantization_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// clang-format off +template +using device_grouped_conv2d_xdl_int8_instances = + std::tuple < + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector> + >; +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp similarity index 61% rename from library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp rename to library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp index e87e987593..b0f86b5c2e 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp @@ -7,66 +7,72 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -void add_device_conv2d_bias_perchannel_quantization_int8_instances( +void add_device_conv2d_xdl_perchannel_quantization_int8_instances( std::vector>>& instances) + Mul2_Clamp>>>& instances) { add_device_operation_instances(instances, - device_conv2d_int8_32Ds_instances{}); + device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_conv2d_int8_32Ds_instances{}); + device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_conv2d_int8_32Ds_instances{}); + device_grouped_conv2d_xdl_int8_instances{}); } -void add_device_conv2d_bias_relu_perchannel_quantization_int8_instances( +void add_device_conv2d_xdl_relu_perchannel_quantization_int8_instances( std::vector>>& instances) + Relu_Mul2_Clamp>>>& instances) { add_device_operation_instances(instances, - device_conv2d_int8_32Ds_instances{}); + device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_conv2d_int8_32Ds_instances{}); + device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_conv2d_int8_32Ds_instances{}); + device_grouped_conv2d_xdl_int8_instances{}); } } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp new file mode 100644 index 0000000000..4812ceecfd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_conv2d_xdl_int8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_conv2d_xdl_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} + +void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_xdl_int8_instances{}); +} +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp deleted file mode 100644 index 06eed76014..0000000000 --- a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp +++ /dev/null @@ -1,68 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_conv2d_xdl_int8_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -void add_device_conv2d_bias_perlayer_quantization_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); -} - -void add_device_conv2d_bias_relu_perlayer_quantization_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_conv2d_int8_32Ds_instances{}); - - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); - - add_device_operation_instances(instances, - device_conv2d_int8_32Ds_instances{}); -} -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp deleted file mode 100644 index 6904e269f9..0000000000 --- a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_int8_instance.hpp +++ /dev/null @@ -1,111 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using Empty_Tuple = ck::Tuple<>; -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; -using GK = ck::tensor_layout::convolution::G_K; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Relu = ck::tensor_operation::element_wise::Relu; - -using GK_Tuple = ck::Tuple; -using GK_GK_Tuple = ck::Tuple; -using I32_Tuple = ck::Tuple; -using F32_Tuple = ck::Tuple; -using I32_F32_Tuple = ck::Tuple; - -using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; -using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; - -using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; -using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; - -using Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; -using Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Activation_Mul2_Clamp; - -using Add_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; -using Add_Relu_Mul2_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; - -static constexpr ck::index_t NDimSpatial = 2; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -template -// clang-format off -using device_conv2d_int8_instances = - std::tuple < - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 16> - >; -// clang-format on - -// for conv + multiple of 32 bit Ds. bit of Ds will affect the ScalarPerVector of C -template -// clang-format off -using device_conv2d_int8_32Ds_instances = - std::tuple < - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, DsLayout, GNHWK, int8_t, int8_t, int32_t, int32_t, DsDatatype, int8_t, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8> - >; -// clang-format on - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp deleted file mode 100644 index 5f1aa0c5c7..0000000000 --- a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp +++ /dev/null @@ -1,62 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_conv2d_xdl_int8_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -void add_device_conv2d_perchannel_quantization_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); -} - -void add_device_conv2d_relu_perchannel_quantization_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_32Ds_instances{}); -} -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp deleted file mode 100644 index 83435d8119..0000000000 --- a/library/src/tensor_operation_instance/gpu/quantization/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp +++ /dev/null @@ -1,62 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "device_conv2d_xdl_int8_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -void add_device_conv2d_perlayer_quantization_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_conv2d_int8_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_instances{}); -} - -void add_device_conv2d_relu_perlayer_quantization_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_conv2d_int8_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_instances{}); - add_device_operation_instances( - instances, - device_conv2d_int8_instances{}); -} -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp new file mode 100644 index 0000000000..9cad8d4c8e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_quantization_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + //####################| A| B| Ds| E| AData| BData| AccData| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //####################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM| Thread| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //####################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Dl< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +template +using device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + //####################| A| B| Ds| E| AData| BData| AccData| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //####################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM| Thread| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //####################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Dl< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +template +using device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + //####################| A| B| Ds| E| AData| BData| AccData| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //####################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM| Thread| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //####################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Dl< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +template +using device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + //####################| A| B| Ds| E| AData| BData| AccData| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //####################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM| Thread| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + //####################| | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + //####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Dl< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..ffe1efb80b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Layout(A, B, C) = [Col, Row, Row] +void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..7f24e5677c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Layout(A, B, C) = [Col, Col, Row] +void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..06e66cfe03 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Layout(A, B, C) = [Row, Row, Row] +void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..16635d1e93 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_dl_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Layout(A, B, C) = [Row, Col, Row] +void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp new file mode 100644 index 0000000000..dfb8dc29b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_quantization_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 4, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +template +using device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 4, 16, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 4, 16, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +template +using device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +template +using device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Empty_Tuple, Row, int8_t, int8_t, int32_t, int32_t, Empty_Tuple, int8_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..c153cdf9ea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Layout(A, B, C) = [Col, Row, Row] +void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances{}); +#endif +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..f6cd32026f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Layout(A, B, C) = [Col, Col, Row] +void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances{}); +#endif +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..45fbacc334 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Layout(A, B, C) = [Row, Row, Row] +void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances{}); +#endif +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..257633fe19 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Layout(A, B, C) = [Row, Col, Row] +void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + add_device_operation_instances( + instances, + device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances{}); +#endif +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp new file mode 100644 index 0000000000..213f42b91b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using Empty_Tuple = ck::Tuple<>; +using Row_Row_Tuple = ck::Tuple; +using Col_Col_Tuple = ck::Tuple; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Relu = ck::tensor_operation::element_wise::Relu; + +using Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; +using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; +using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; + +static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp new file mode 100644 index 0000000000..87e6ae44c7 --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_fastgelu_impl.hpp @@ -0,0 +1,280 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fastgelu.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_gemm_fastgelu_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs) +{ + + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideCs.size())) + { + throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> c_m_n_device_results; + + for(std::size_t i = 0; i < group_count; i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + + c_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); + + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i + << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{}(a_m_k[i]); + ck::utils::FillUniformDistributionIntegerValue{}(b_k_n[i]); + break; + default: + ck::utils::FillUniformDistribution{0.0, 1.0}(a_m_k[i]); + ck::utils::FillUniformDistribution{-0.5, 0.5}(b_k_n[i]); + } + + ck::utils::FillConstant{}(c_m_n_device_results[i]); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::FastGelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, c_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + c_device_buf.reserve(group_count); + + std::vector p_a, p_b; + std::vector p_c; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_c.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); + c_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + c_device_buf[i]->SetZero(); + + gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + auto p_ds = std::vector>{}; + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + + c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})); + + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + b_k_n[i], + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + bool group_pass = + ck::utils::check_err(c_m_n_device_results[i], c_m_n_host_result); + pass = pass && group_pass; + + std::cout << "group: " << i << " verification result: " << std::boolalpha + << group_pass << std::endl; + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + } + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + if(do_verification) + { + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << std::endl; + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 04f94a0f24..2c28850e65 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -108,11 +108,6 @@ bool profile_grouped_gemm_impl(int do_verification, const auto b_element_op = BElementOp{}; const auto c_element_op = CElementOp{}; - // if(do_verification) - // { - - // } - using DeviceMemPtr = std::unique_ptr; std::vector a_device_buf, b_device_buf, c_device_buf; @@ -285,7 +280,7 @@ bool profile_grouped_gemm_impl(int do_verification, << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; return pass; -} // namespace profiler +} } // namespace profiler } // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index d3ab88a167..c21fff7de9 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -29,6 +29,7 @@ set(PROFILER_SOURCES profile_batchnorm_fwd.cpp profile_batchnorm_bwd.cpp profile_batchnorm_infer.cpp + profile_grouped_gemm_fastgelu.cpp ) set(PROFILER_EXECUTABLE ckProfiler) @@ -68,4 +69,5 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/profiler/src/profile_grouped_gemm.cpp b/profiler/src/profile_grouped_gemm.cpp index 65e24bd9cc..871b2edfd5 100644 --- a/profiler/src/profile_grouped_gemm.cpp +++ b/profiler/src/profile_grouped_gemm.cpp @@ -29,6 +29,11 @@ enum struct GemmDataType INT8_INT8_INT8, // 3 }; +#define OP_NAME "grouped_gemm" +#define OP_DESC "Grouped GEMM" + +namespace { + std::vector argToIntArray(char* input) { std::vector out; @@ -45,9 +50,6 @@ std::vector argToIntArray(char* input) return out; } -#define OP_NAME "grouped_gemm" -#define OP_DESC "Grouped GEMM" - int profile_grouped_gemm(int argc, char* argv[]) { if(!(argc == 14)) @@ -166,4 +168,6 @@ int profile_grouped_gemm(int argc, char* argv[]) return 0; } +} // anonymous namespace + REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm); diff --git a/profiler/src/profile_grouped_gemm_fastgelu.cpp b/profiler/src/profile_grouped_gemm_fastgelu.cpp new file mode 100644 index 0000000000..9b6142f015 --- /dev/null +++ b/profiler/src/profile_grouped_gemm_fastgelu.cpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_grouped_gemm_fastgelu_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 +}; + +#define OP_NAME "grouped_gemm_fastgelu" +#define OP_DESC "Grouped GEMM+FastGelu" + +namespace { + +std::vector argToIntArray(char* input) +{ + std::vector out; + + std::istringstream in(input); + + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + + return out; +} + +int profile_grouped_gemm_fastgelu(int argc, char* argv[]) +{ + if(!(argc == 14)) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=n0, 1=yes)\n"); + printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const auto Ms = argToIntArray(argv[8]); + const auto Ns = argToIntArray(argv[9]); + const auto Ks = argToIntArray(argv[10]); + + const auto StrideAs = argToIntArray(argv[11]); + const auto StrideBs = argToIntArray(argv[12]); + const auto StrideCs = argToIntArray(argv[13]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_fastgelu_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_grouped_gemm_fastgelu_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_grouped_gemm_fastgelu_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_grouped_gemm_fastgelu_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs); + } + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + + return 0; +} + +} // anonymous namespace + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_fastgelu); diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 40f318c789..426f68d443 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -8,7 +8,8 @@ MY_PROJECT_SOURCE=$1 cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ +-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker \ +-save-temps=$PWD" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS="gfx908;gfx90a;gfx940" \ diff --git a/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp index 6545b6e566..75f934cc06 100644 --- a/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp @@ -43,7 +43,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test DataType, DataType, DataType>(true, // do_verification - 1, // init_method integer value + 1, // init_method: integer value false, // do_log false, // time_kernel param, @@ -60,9 +60,9 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeight, KernelTypes); TYPED_TEST(TestGroupedConvndBwdWeight, Test1D) { this->conv_params.clear(); - this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); - this->conv_params.push_back({1, 4, 64, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); - this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); this->template Run<1>(); } @@ -70,11 +70,11 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test2D) { this->conv_params.clear(); this->conv_params.push_back( - {2, 4, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); this->conv_params.push_back( - {2, 4, 8, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); this->conv_params.push_back( - {2, 4, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); this->template Run<2>(); } @@ -82,10 +82,10 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test3D) { this->conv_params.clear(); this->conv_params.push_back( - {3, 4, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->conv_params.push_back( - {3, 4, 8, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); this->conv_params.push_back( - {3, 4, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); this->template Run<3>(); }