From 274108d6e6345f6a4ef4e5d9b23dbfd3a8c2b2f8 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Mon, 30 Jan 2023 20:03:59 +0100 Subject: [PATCH 01/22] Use defined seed for deterministic test runs. (#562) Co-authored-by: Adam Osewski --- test/grouped_gemm/grouped_gemm_fp16.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp index b3f7cca418..f20d750d36 100644 --- a/test/grouped_gemm/grouped_gemm_fp16.cpp +++ b/test/grouped_gemm/grouped_gemm_fp16.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. #include +#include #include "profiler/profile_grouped_gemm_impl.hpp" @@ -18,7 +19,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; template bool TestGroupedGemm() { - int group_count = rand() % 10 + 1; + + std::mt19937 gen(19391); + std::uniform_int_distribution<> distrib(1, 10); + int group_count = distrib(gen); // GEMM shape std::vector gemm_descs; @@ -29,9 +33,9 @@ bool TestGroupedGemm() for(int i = 0; i < group_count; i++) { - Ms.push_back(256 + 256 * (rand() % 10)); - Ns.push_back(256 + 256 * (rand() % 10)); - Ks.push_back(128 + 128 * (rand() % 10)); + Ms.push_back(256 + 256 * distrib(gen)); + Ns.push_back(256 + 256 * distrib(gen)); + Ks.push_back(128 + 128 * distrib(gen)); StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); From ba40c2ce9df97c49a2f554a6f2ae1bd1936f0d33 Mon Sep 17 00:00:00 2001 From: who who who Date: Tue, 31 Jan 2023 10:34:35 +0800 Subject: [PATCH 02/22] remove unused variable (#564) * remove unused variable * format code --- .../gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp index 53942b9952..ff2511fa6e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp @@ -185,9 +185,7 @@ struct GridwiseSparseEmbeddingsForwardLayernorm }, Number{}); auto out_data_refs = generate_tie( - [&](auto output_index_) -> auto& { - return acc_thread_buf(Number{}); - }, + [&](auto) -> auto& { return acc_thread_buf(Number{}); }, Number<1>{}); unpack2(emb_elementwise_op, out_data_refs, in_data_refs); }); From afdfef74f786e6e36fcb99108e306a717c00b79b Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Wed, 1 Feb 2023 15:56:59 -0600 Subject: [PATCH 03/22] Add the markdown tutorial hello world (#563) * Add the markdown tutorial * Clean up --------- Co-authored-by: Rosty Geyyer --- doc/markdown/tutorial_hello_world.md | 191 +++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 doc/markdown/tutorial_hello_world.md diff --git a/doc/markdown/tutorial_hello_world.md b/doc/markdown/tutorial_hello_world.md new file mode 100644 index 0000000000..297df10b5c --- /dev/null +++ b/doc/markdown/tutorial_hello_world.md @@ -0,0 +1,191 @@ +## CK Hello world + +## Motivation + +This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who would like to optimize their pipelines and squeeze every performance drop by adding Composable Kernel (CK) library to their projects. We would like to make the CK library approachable so the tutorial is not based on the latest release and doesn't have all the bleeding edge features, but it will be reproducible now and forever. + +During this tutorial we will have an introduction to the CK library, we will build it and run some examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go in depth and breadth and get familiar with other tools and ways to integrate CK into your project. + +## Description + +Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create new ones. The library has components required for majority of modern neural networks architectures including matrix multiplication, convolution, contraction, reduction, attention modules, variety of activation functions, fused operators and many more. + +So how do we (almost) reach the speed of light? CK acceleration abilities are based on: + +* Layered structure. +* Tile-based computation model. +* Tensor coordinate transformation. +* Hardware acceleration use. +* Support of low precision data types including fp16, bf16, int8 and int4. + +If you are excited and need more technical details and benchmarking results - read this awesome blog [post](https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224). + +For more details visit our [github repo](https://github.com/ROCmSoftwarePlatform/composable_kernel). + +## Hardware targets + +CK library fully supports "gfx908" and "gfx90a" GPU architectures and only some operators are supported for "gfx1030". Let's check the hardware you have at hand and decide on the target GPU architecture + +GPU Target AMD GPU +gfx908 Radeon Instinct MI100 +gfx90a Radeon Instinct MI210, MI250, MI250X +gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT + +There are also [cloud options](https://aws.amazon.com/ec2/instance-types/g4/) you can find if you don't have an AMD GPU at hand. + +## Build the library + +First let's clone the library and rebase to the tested version: + +``` +git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git +cd composable_kernel/ +git checkout tutorial_hello_world +``` + +To make our lives easier we prepared [docker images](https://hub.docker.com/r/rocm/composable_kernel) with all the necessary dependencies. Pick the right image and create a container. In this tutorial we use "rocm/composable_kernel:ck_ub20.04_rocm5.3_release" image, it is based on Ubuntu 20.04, ROCm v5.3, compiler release version. + +If your current folder is ${HOME}, start the docker container with + +``` +docker run \ +-it \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${HOME}:/root/workspace \ +rocm/composable_kernel:ck_ub20.04_rocm5.3_release \ +/bin/bash +``` + +If your current folder is different from ${HOME}, adjust the line `-v ${HOME}:/root/workspace` to fit your folder structure. + +Inside the docker container current folder is "~/workspace", library path is "~/workspace/composable_kernel", navigate to the library + +``` +cd composable_kernel/ +``` + +Create and go to the "build" directory + +``` +mkdir build && cd build +``` + +In the previous section we talked about target GPU architecture. Once you decide which one is right for you, run cmake using the right GPU_TARGETS flag + +``` +cmake \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_CXX_FLAGS="-O3" \ +-D CMAKE_BUILD_TYPE=Release \ +-D BUILD_DEV=OFF \ +-D GPU_TARGETS="gfx908;gfx90a;gfx1030" .. +``` + +If everything went well the cmake run will end up with: + +``` +-- Configuring done +-- Generating done +-- Build files have been written to: "/root/workspace/composable_kernel/build" +``` + +Finally, we can build examples and tests + +``` +make -j examples tests +``` + +If everything is smooth, you'll see + +``` +Scanning dependencies of target tests +[100%] Built target tests +``` + +## Run examples and tests + +Examples are listed as test cases as well, so we can run all examples and tests with + +``` +ctest +``` + +You can check the list of all tests by running + +``` +ctest -N +``` + +We can also run them separately, here is a separate example execution. + +``` +./bin/example_gemm_xdl_fp16 1 1 1 +``` + +The arguments "1 1 1" mean that we want to run this example in the mode: verify results with CPU, initialize matrices with integers and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change. + +If everything goes well and you have a device based on gfx908 or gfx90a architecture you should see something like + +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 +``` + +Meanwhile, running it on a gfx1030 device should result in + +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem +``` + +But don't panic, some of the operators are supported on gfx1030 architecture, so you can run a separate example like + +``` +./bin/example_gemm_dl_fp16 1 1 1 +``` + +and it should result in something nice similar to + +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {1, 4096} +b_k_n: dim 2, lengths {4096, 4096}, strides {4096, 1} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +arg.a_grid_desc_k0_m0_m1_k1_{2048, 3840, 2} +arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} +arg.c_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {960, 1, 1}, block_dim {256, 1, 1} +Warm up 1 time +Start running 10 times... +Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1> +``` + +Or we can run a separate test + +``` +ctest -R test_gemm_fp16 +``` + +If everything goes well you should see something like + +``` +Start 121: test_gemm_fp16 +1/1 Test #121: test_gemm_fp16 ................... Passed 51.81 sec + +100% tests passed, 0 tests failed out of 1 +``` + +## Summary + +In this tutorial we took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different configs to find out the best one for your hardware and task. + +P.S.: Don't forget to switch out the cloud instance if you have launched one, you can find better ways to spend your money for sure! From f73574ffdd3bb50e6696dfa627797eda7248eb94 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 6 Feb 2023 11:15:45 -0800 Subject: [PATCH 04/22] Fix CI issues. (#572) * switch to recent staging compiler as default for CI * fix the baseline query * roll back sqlalchemy to version 1.4.46 --- Dockerfile | 2 +- Jenkinsfile | 17 +++++++++-------- script/process_perf_data.py | 5 +++-- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/Dockerfile b/Dockerfile index d024f966c5..dd2a97c7bd 100644 --- a/Dockerfile +++ b/Dockerfile @@ -60,7 +60,7 @@ RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb ARG PREFIX=/opt/rocm # Install packages for processing the performance results RUN pip3 install --upgrade pip -RUN pip3 install sqlalchemy +RUN pip3 install sqlalchemy==1.4.46 RUN pip3 install pymysql RUN pip3 install pandas RUN pip3 install setuptools-rust diff --git a/Jenkinsfile b/Jenkinsfile index 7b2e57c140..11996d72bd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -550,8 +550,9 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;COMPILER_VERSION=release - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open''' : "" +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true + 0 21 * * * % RUN_FULL_QA=false;COMPILER_VERSION=release;COMPILER_COMMIT="" + 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=""''' : "" pipeline { agent none @@ -568,16 +569,16 @@ pipeline { description: "Force building docker image (default: false), set to true if docker image needs to be updated.") string( name: 'ROCMVERSION', - defaultValue: '5.3', - description: 'Specify which ROCM version to use: 5.2.3, or 5.3 (default), etc.') + defaultValue: '5.4.3', + description: 'Specify which ROCM version to use: 5.4.3 (default).') string( name: 'COMPILER_VERSION', - defaultValue: 'release', - description: 'Specify which version of compiler to use: ck-9110, release (default), or amd-stg-open.') + defaultValue: 'amd-stg-open', + description: 'Specify which version of compiler to use: ck-9110, release, or amd-stg-open (default).') string( name: 'COMPILER_COMMIT', - defaultValue: '', - description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit (default), or use 8a82e4eb7ba28521ba9a9424a0315a8a16590424 commit of amd-stg-open branch.') + defaultValue: '5541927df00eabd6a110180170eca7785d436ee3', + description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.') string( name: 'BUILD_COMPILER', defaultValue: 'hipcc', diff --git a/script/process_perf_data.py b/script/process_perf_data.py index 638e4ef564..e8b8e1458c 100644 --- a/script/process_perf_data.py +++ b/script/process_perf_data.py @@ -3,6 +3,7 @@ import os, io, argparse, datetime #import numpy as np import sqlalchemy from sqlalchemy.types import NVARCHAR, Float, Integer +from sqlalchemy import text import pymysql import pandas as pd from sshtunnel import SSHTunnelForwarder @@ -141,8 +142,8 @@ def parse_logfile(logfile): def get_baseline(table, connection): - query = '''SELECT * from '''+table+''' WHERE Datetime = (SELECT MAX(Datetime) FROM '''+table+''' where Branch_ID='develop' );''' - return pd.read_sql_query(query, connection) + query = text('''SELECT * from '''+table+''' WHERE Datetime = (SELECT MAX(Datetime) FROM '''+table+''' where Branch_ID='develop' );''') + return pd.read_sql(query, connection) def store_new_test_result(table_name, test_results, testlist, branch_name, node_id, gpu_arch, compute_units, rocm_vers, hip_vers, environment, connection): params=[str(branch_name),str(node_id),str(gpu_arch),compute_units,str(rocm_vers),str(hip_vers),str(environment),str(datetime.datetime.now())] From bb3d9546f186e39cefedc3e7f01d88924ba20168 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 8 Feb 2023 09:50:09 -0800 Subject: [PATCH 05/22] Fix a couple more CI issues. (#578) * test the QA cron parameter for compiler commit * create separate dockers for latest and fixed amd-stg-open compiler versions * change groovy syntax * apply cron timers back to develop branch --- Jenkinsfile | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 11996d72bd..4e4c7ad8fa 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -19,7 +19,14 @@ def runShell(String command){ } def getDockerImageName(){ - def img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" + def img + if (params.COMPILER_COMMIT == ""){ + img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" + } + else{ + def commit = "${params.COMPILER_COMMIT}"[0..6] + img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}" + } return img } @@ -551,8 +558,8 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true - 0 21 * * * % RUN_FULL_QA=false;COMPILER_VERSION=release;COMPILER_COMMIT="" - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=""''' : "" + 0 21 * * * % RUN_FULL_QA=false;COMPILER_VERSION=release;COMPILER_COMMIT= + 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" pipeline { agent none From 332ccc3367cc48fc0bdcd696574e03943495c24a Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 9 Feb 2023 04:34:45 +0800 Subject: [PATCH 06/22] Add GemmAddSoftmaxGemm support for MSFT ORT (instances and client API) (#576) * add instance for gemm bias softmax gemm * add client example * change CGridDesc_G_M_N to CGridDesc_G_M_O * add gridwise * change c grid name * device add d0s data * fix 08 client_example * add example 47_fused_attention * example output correct * add d0 to example * add d0 element op * rechange instance code * change Acc0ElementwiseOperation to C0DEElementwiseOperation * change example name * update instance for cdeelementwiseop * add bhalf_t ScaleAdd * add test * not surport geem1 bias * remove some ignore * fix test bug --- .../08_fused_attention/CMakeLists.txt | 3 + .../fused_attention_bias.cpp | 226 +++ .../CMakeLists.txt | 1 + .../gemm_bias_softmax_gemm_permute.cpp | 408 +++++ ...vice_batched_gemm_softmax_gemm_permute.hpp | 8 +- ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 279 ++-- .../element/binary_element_wise_operation.hpp | 32 + ...ultiple_d_softmax_gemm_xdl_cshuffle_v1.hpp | 1329 +++++++++++++++++ .../device_operation_instance_factory.hpp | 1 + ...batched_gemm_bias_softmax_gemm_permute.hpp | 190 +++ .../CMakeLists.txt | 2 + ...f16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp | 133 ++ ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp | 133 ++ ...ed_gemm_bias_softmax_gemm_permute_impl.hpp | 395 +++++ .../CMakeLists.txt | 9 +- ...ed_gemm_bias_softmax_gemm_permute_bf16.cpp | 182 +++ ...ed_gemm_bias_softmax_gemm_permute_fp16.cpp | 182 +++ ...ed_gemm_bias_softmax_gemm_permute_util.hpp | 380 +++++ 18 files changed, 3787 insertions(+), 106 deletions(-) create mode 100644 client_example/08_fused_attention/fused_attention_bias.cpp create mode 100644 example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt create mode 100644 example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp create mode 100644 profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp create mode 100644 test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp create mode 100644 test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp create mode 100644 test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp diff --git a/client_example/08_fused_attention/CMakeLists.txt b/client_example/08_fused_attention/CMakeLists.txt index 5cdea72fd9..862b9ed5b7 100644 --- a/client_example/08_fused_attention/CMakeLists.txt +++ b/client_example/08_fused_attention/CMakeLists.txt @@ -1,2 +1,5 @@ add_executable(client_fused_attention fused_attention.cpp) target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_operations) + +add_executable(client_fused_attention_bias fused_attention_bias.cpp) +target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations) diff --git a/client_example/08_fused_attention/fused_attention_bias.cpp b/client_example/08_fused_attention/fused_attention_bias.cpp new file mode 100644 index 0000000000..3113b78560 --- /dev/null +++ b/client_example/08_fused_attention/fused_attention_bias.cpp @@ -0,0 +1,226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using B0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using B1ElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +constexpr static auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +using ADataType = ck::half_t; +using B0DataType = ck::half_t; +using B1DataType = ck::half_t; +using CDataType = ck::half_t; +using D0DataType = ck::half_t; +using AccDataType = float; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main(int argc, char* argv[]) +{ + int G0 = 48; + int G1 = 16; + int M = 1024; + int N = 1024; + int K = 64; + int O = 64; + + // A layout [G0, M, G1, K] + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}; + + // B0 layout [G0, N, G1, K] + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1}; + + // B1 layout [G0, N, G1, O] + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O}; + + // C layout [G0, M, G1, O] + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; + + // D layout [G0, M, G1, N] + std::vector d0_gs_ms_ns_lengths{G0, G1, M, N}; + std::vector d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K); + SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K); + SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N); + SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N); + SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O); + + using DeviceOp = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple, + ck::Tuple<>, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + MaskingSpec>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases + {}, // p_acc1_biases + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + std::array, 1>{ + d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths + std::array, 1>{ + d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}, // acc1_biases_gs_ms_os_strides + AElementOp{}, + B0ElementOp{}, + Acc0ElementOp{1 / sqrtf(K)}, + B1ElementOp{}, + CElementOp{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + + sizeof(D0DataType) * M * N) * + G0 * G1; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best instance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + c_device_buf.GetDeviceBuffer(), + std::array{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases + {}, // p_acc1_biases + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + std::array, 1>{ + d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths + std::array, 1>{ + d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}, // acc1_biases_gs_ms_os_strides + AElementOp{}, + B0ElementOp{}, + Acc0ElementOp{1 / sqrtf(K)}, + B1ElementOp{}, + CElementOp{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt new file mode 100644 index 0000000000..d1b3dd4be2 --- /dev/null +++ b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp new file mode 100644 index 0000000000..30c98e534a --- /dev/null +++ b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using B0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using C0DEElementOp = ck::tensor_operation::element_wise::ScaleAdd; +using Acc0ElementOp = ck::tensor_operation::element_wise::PassThrough; +using B1ElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; + +template +using S = ck::Sequence; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +constexpr static auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +using F16 = ck::half_t; +using F32 = float; +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using D0DataType = F16; +using Acc0BiasDataType = ck::Tuple; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< + NumDimG, + NumDimM, + NumDimN, + NumDimK, + NumDimO, + ADataType, + B0DataType, + B1DataType, + CDataType, + Acc0BiasDataType, + Acc1BiasDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + C0DEElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecA, + TensorSpecB0, + TensorSpecB1, + TensorSpecC, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 64, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 2, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<16, 16, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpec>; // MaskingSpecialization + +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + int G0 = 3; + int G1 = 2; + int M = 1024; + int N = 1024; + int K = 64; + int O = 64; + float alpha = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{ + M * G1 * K, K, G1 * K, 1}; // A layout [G0, M, G1, K] + + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{ + N * G1 * K, K, G1 * K, 1}; // B0 layout [G0, N, G1, K] + + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{ + N * G1 * O, O, 1, G1 * O}; // B1 layout [G0, N, G1, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides{ + M * G1 * O, O, G1 * O, 1}; // C layout [G0, M, G1, O] + + // D layout [G0, M, G1, N] + std::vector d0_gs_ms_ns_lengths{G0, G1, M, N}; + std::vector d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K); + DeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K); + DeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N); + DeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N); + DeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data()); + + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto c0de_element_op = C0DEElementOp{alpha}; + auto acc0_element_op = Acc0ElementOp{}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + auto argument = device_op.MakeArgument( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + std::array{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases + {}, // p_acc1_biases + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + std::array, 1>{ + d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths + std::array, 1>{ + d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}, // acc1_biases_gs_ms_os_strides + a_element_op, + b0_element_op, + c0de_element_op, + b1_element_op, + c_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error("wrong! this device_op instance does not support this problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t BatchCount = G0 * G1; + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = + (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O + + sizeof(CDataType) * M * O + sizeof(D0DataType) * M * N) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, M, K}); + Tensor b0_g_k_n({BatchCount, K, N}); + Tensor b1_g_n_o({BatchCount, N, O}); + Tensor acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 + Tensor d0_g_m_n({BatchCount, M, N}); + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + d0_gs_ms_ns.ForEach([&](auto& self, auto idx) { + d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + acc0_g_m_n.ForEach([&](auto&, auto idx) { + c0de_element_op(acc0_g_m_n(idx), acc0_g_m_n(idx), d0_g_m_n(idx)); + }); + // masking + const auto mask = DeviceOpInstance::C0MatrixMask(N); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * G1 + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + return ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol) + ? 0 + : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp index ff55519980..bde71806da 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp @@ -26,9 +26,9 @@ template struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator { @@ -58,9 +58,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AElementwiseOperation a_element_op, B0ElementwiseOperation b0_element_op, - Acc0ElementwiseOperation acc0_element_op, + C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, - CElementwiseOperation c_element_op) = 0; + C1DEElementwiseOperation c1de_element_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 59b6af1edb..6c38347381 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -13,7 +13,7 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp" #include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -25,15 +25,17 @@ namespace device { template (compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); + p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; + }); + GridwiseGemm::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, p_b1_grid + b1_batch_offset, p_c_grid + c_batch_offset, + p_d0s_grid, p_shared, a_element_op, b_element_op, - acc_element_op, + c0de_element_op, b1_element_op, - c_element_op, + c1de_element_op, a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, b1_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, block_2_ctile_map, c0_matrix_mask); #else @@ -100,13 +113,14 @@ __global__ void ignore = p_c_grid; ignore = a_element_op; ignore = b_element_op; - ignore = acc_element_op; + ignore = c0de_element_op; ignore = b1_element_op; - ignore = c_element_op; + ignore = c1de_element_op; ignore = a_grid_desc_ak0_m_ak1; ignore = b_grid_desc_bk0_n_bk1; ignore = b1_grid_desc_bk0_n_bk1; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; ignore = block_2_ctile_map; ignore = batch_count; ignore = compute_base_ptr_of_batch; @@ -126,15 +140,15 @@ template { static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, "Number of dimension must be greater than 0"); - static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); - static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); // TODO ANT: implement bias combination - static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + static_assert(NumD1Tensor == 0, "Gemm1 Bias addition is unimplemented"); #if 0 // TODO ANT: use alias @@ -261,14 +275,40 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle Number{}); } + static auto MakeD0sGridDescriptor_M_N( + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_strides) + { + return generate_tuple( + [&](auto i) { + return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i], + acc0_biases_gs_ms_ns_strides[i]); + }, + Number{}); + } + + static auto MakeD0sGridDescriptor_G_M_N( + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_strides) + { + return generate_tuple( + [&](auto i) { + return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i], + acc0_biases_gs_ms_ns_strides[i]); + }, + Number{}); + } + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); - using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using C1GridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); - using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + using C1GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {})); + using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {})); constexpr static auto make_MaskOutPredicate() { @@ -288,11 +328,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, const BGridDesc_G_N_K& b_grid_desc_g_n_k, const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, - const CGridDesc_G_M_N& c_grid_desc_g_m_n) + const C1GridDesc_G_M_N& c1_grid_desc_g_m_n, + const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n) : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), b_grid_desc_g_n_k_(b_grid_desc_g_n_k), b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), - c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + c1_grid_desc_g_m_n_(c1_grid_desc_g_m_n), + d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n) { } @@ -313,32 +355,42 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const { - return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + return c1_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + template + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, + Number d0_idx) const + { + return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0)); } private: AGridDesc_G_M_K a_grid_desc_g_m_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_; - CGridDesc_G_M_N c_grid_desc_g_m_n_; + C1GridDesc_G_M_N c1_grid_desc_g_m_n_; + D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_; }; // GridwiseGemm - using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + using GridwiseGemm = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< ADataType, // TODO: distinguish A/B datatype GemmAccDataType, CShuffleDataType, CDataType, + D0sDataType, AElementwiseOperation, BElementwiseOperation, - AccElementwiseOperation, + C0DEElementwiseOperation, B1ElementwiseOperation, - CElementwiseOperation, + C1DEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, - CGridDesc_M_N, + C1GridDesc_M_N, + D0sGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, @@ -395,8 +447,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle const BDataType* p_b_grid, const B1DataType* p_b1_grid, CDataType* p_c_grid, - const std::array p_acc0_biases, - const std::array p_acc1_biases, + const std::array p_acc0_biases, + const std::array p_acc1_biases, const std::vector& a_gs_ms_ks_lengths, const std::vector& a_gs_ms_ks_strides, const std::vector& b_gs_ns_ks_lengths, @@ -405,44 +457,48 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides - const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, - const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, - const std::array, NumAcc1Bias> + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_strides, + const std::array, NumD1Tensor>& acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths - const std::array, NumAcc1Bias> + const std::array, NumD1Tensor>& acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - AccElementwiseOperation acc_element_op, + C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, - CElementwiseOperation c_element_op) + C1DEElementwiseOperation c1de_element_op) : p_a_grid_{p_a_grid}, p_b_grid_{p_b_grid}, p_b1_grid_{p_b1_grid}, p_c_grid_{p_c_grid}, + p_d0s_grid_{}, a_grid_desc_ak0_m_ak1_{ DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, b_grid_desc_bk0_n_bk1_{ DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, - c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, - c_gs_ms_gemm1ns_strides)}, + c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, + c_gs_ms_gemm1ns_strides)}, a_grid_desc_g_m_k_{ Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, b_grid_desc_g_n_k_{ Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K( b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, - c_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, - c_gs_ms_gemm1ns_strides)}, - c_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + c1_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, + c_gs_ms_gemm1ns_strides)}, + d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N( + acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)}, + c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, - acc_element_op_{acc_element_op}, + c0de_element_op_{c0de_element_op}, b1_element_op_{b1_element_op}, - c_element_op_{c_element_op}, + c1de_element_op_{c1de_element_op}, c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], @@ -456,27 +512,39 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]}, c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, - batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, - compute_base_ptr_of_batch_{ - a_grid_desc_g_m_k_, b_grid_desc_g_n_k_, b1_grid_desc_g_n_k_, c_grid_desc_g_m_n_} + batch_count_{c1_grid_desc_g_m_n_.GetLength(I0)}, + compute_base_ptr_of_batch_{a_grid_desc_g_m_k_, + b_grid_desc_g_n_k_, + b1_grid_desc_g_n_k_, + c1_grid_desc_g_m_n_, + d0s_grid_desc_g_m_n_} { // TODO ANT: implement bias addition - ignore = p_acc0_biases; ignore = p_acc1_biases; - ignore = acc0_biases_gs_ms_ns_lengths; - ignore = acc0_biases_gs_ms_ns_strides; ignore = acc1_biases_gs_ms_gemm1ns_lengths; ignore = acc1_biases_gs_ms_gemm1ns_strides; + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + using D0DataType = remove_cvref_t>; + // D0 pointer + p_d0s_grid_(i) = static_cast(p_acc0_biases[i]); + }); + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, b1_grid_desc_bk0_n_bk1_, - c_grid_desc_m_n_, + c1_grid_desc_m_n_, block_2_ctile_map_)) { - c_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n_); + c1_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c1_grid_desc_m_n_); + + D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N( + acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)}; + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = + GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( + d0s_grid_desc_m_n); } } @@ -491,9 +559,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " << b1_grid_desc_g_n_k_.GetLength(I1) << ", " << b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; - std::cout << "c_grid_desc_g_m_n_: " << c_grid_desc_g_m_n_.GetLength(I0) << ", " - << c_grid_desc_g_m_n_.GetLength(I1) << ", " - << c_grid_desc_g_m_n_.GetLength(I2) << '\n'; + std::cout << "c1_grid_desc_g_m_n_: " << c1_grid_desc_g_m_n_.GetLength(I0) << ", " + << c1_grid_desc_g_m_n_.GetLength(I1) << ", " + << c1_grid_desc_g_m_n_.GetLength(I2) << '\n'; } // pointers @@ -501,18 +569,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle const BDataType* p_b_grid_; const B1DataType* p_b1_grid_; CDataType* p_c_grid_; + typename GridwiseGemm::D0sGridPointer p_d0s_grid_; // tensor descriptor AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; - CGridDesc_M_N c_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; AGridDesc_G_M_K a_grid_desc_g_m_k_; BGridDesc_G_N_K b_grid_desc_g_n_k_; B1GridDesc_G_N_K b1_grid_desc_g_n_k_; - CGridDesc_G_M_N c_grid_desc_g_m_n_; - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock_; + C1GridDesc_G_M_N c1_grid_desc_g_m_n_; + D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_; + + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; // block-to-c-tile map typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; @@ -520,9 +593,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle // element-wise op AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; - AccElementwiseOperation acc_element_op_; + C0DEElementwiseOperation c0de_element_op_; B1ElementwiseOperation b1_element_op_; - CElementwiseOperation c_element_op_; + C1DEElementwiseOperation c1de_element_op_; // check C0 masking and padding C0MatrixMask c0_matrix_mask_; @@ -551,7 +624,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle } const index_t grid_size = - arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; + arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_; // Gemm0_K const auto K = @@ -564,15 +637,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, + typename GridwiseGemm::D0sGridPointer, AElementwiseOperation, BElementwiseOperation, - AccElementwiseOperation, + C0DEElementwiseOperation, B1ElementwiseOperation, - CElementwiseOperation, + C1DEElementwiseOperation, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::B1GridDesc_BK0_N_BK1, - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, typename GridwiseGemm::DefaultBlock2CTileMap, ComputeBasePtrOfStridedBatch, C0MatrixMask, @@ -587,15 +662,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle arg.p_b_grid_, arg.p_b1_grid_, arg.p_c_grid_, + arg.p_d0s_grid_, arg.a_element_op_, arg.b_element_op_, - arg.acc_element_op_, + arg.c0de_element_op_, arg.b1_element_op_, - arg.c_element_op_, + arg.c1de_element_op_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, arg.block_2_ctile_map_, arg.batch_count_, arg.compute_base_ptr_of_batch_, @@ -644,9 +721,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle // TODO ANT: Check if tensor specialization & strides mismatch // Check if C permute dimension matches GEMM + GEMM shape - const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded - const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0); - const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1); + const index_t c_g = arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded + const index_t c_m = arg.c1_grid_desc_m_n_.GetLength(I0); + const index_t c_gemm1n = arg.c1_grid_desc_m_n_.GetLength(I1); const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); @@ -696,7 +773,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.b1_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_m_n_, + arg.c1_grid_desc_m_n_, arg.block_2_ctile_map_); } @@ -711,8 +788,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle const BDataType* p_b, const B1DataType* p_b1, CDataType* p_c, - const std::array p_acc0_biases, - const std::array p_acc1_biases, + const std::array p_acc0_biases, + const std::array p_acc1_biases, const std::vector& a_gs_ms_ks_lengths, const std::vector& a_gs_ms_ks_strides, const std::vector& b_gs_ns_ks_lengths, @@ -721,17 +798,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides - const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, - const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, - const std::array, NumAcc1Bias> + const std::array, NumD0Tensor> acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor> acc0_biases_gs_ms_ns_strides, + const std::array, NumD1Tensor> acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths - const std::array, NumAcc1Bias> + const std::array, NumD1Tensor> acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - AccElementwiseOperation acc_element_op, + C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, - CElementwiseOperation c_element_op) + C1DEElementwiseOperation c1de_element_op) { return Argument{p_a, p_b, @@ -753,9 +830,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides a_element_op, b_element_op, - acc_element_op, + c0de_element_op, b1_element_op, - c_element_op}; + c1de_element_op}; } static auto MakeInvoker() { return Invoker{}; } @@ -767,8 +844,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle const void* p_b, const void* p_b1, void* p_c, - const std::array p_acc0_biases, - const std::array p_acc1_biases, + const std::array p_acc0_biases, + const std::array p_acc1_biases, const std::vector& a_gs_ms_ks_lengths, const std::vector& a_gs_ms_ks_strides, const std::vector& b_gs_ns_ks_lengths, @@ -777,17 +854,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides - const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, - const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, - const std::array, NumAcc1Bias> + const std::array, NumD0Tensor> acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor> acc0_biases_gs_ms_ns_strides, + const std::array, NumD1Tensor> acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths - const std::array, NumAcc1Bias> + const std::array, NumD1Tensor> acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - AccElementwiseOperation acc_element_op, + C0DEElementwiseOperation c0de_element_op, B1ElementwiseOperation b1_element_op, - CElementwiseOperation c_element_op) override + C1DEElementwiseOperation c1de_element_op) override { return std::make_unique(static_cast(p_a), static_cast(p_b), @@ -809,9 +886,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle acc1_biases_gs_ms_gemm1ns_strides, a_element_op, b_element_op, - acc_element_op, + c0de_element_op, b1_element_op, - c_element_op); + c1de_element_op); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index a4053b1f36..d0ba07e5a9 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -49,6 +49,14 @@ struct Add y = x0 + x1; }; + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 + x1_tmp; + } + template <> __host__ __device__ constexpr void operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const @@ -67,6 +75,30 @@ struct Add }; }; +struct ScaleAdd +{ + __host__ __device__ ScaleAdd(float scale) : scale_(scale) {} + + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = scale_ * x0 + ck::type_convert(x1); + }; + + template <> + __host__ __device__ void + operator()(float& y, const float& x0, const bhalf_t& x1) const + { + y = scale_ * x0 + ck::type_convert(x1); + }; + + float scale_; +}; + struct Subtract { template diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000..da93f63538 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,1329 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" + +namespace ck { + +template +struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle +{ + static_assert(LoopSched == LoopScheduler::Default, + "Non-default loop scheduler is currently not supported"); + + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + // Gemm0 + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); + static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); + + // Gemm1 + static constexpr auto B1K0 = Number{}; + static constexpr auto B1K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + template + __host__ __device__ static constexpr auto + MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B1 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B1K0, Number{}, B1K1), + make_tuple(Number{} * B1K1, B1K1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + + SharedMemTrait::b_block_space_size_aligned) * + sizeof(FloatAB); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(FloatAB); + const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned) * + sizeof(FloatGemmAcc); + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const C1GridDesc_M_N& c1_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); + + if(!(M == c1_grid_desc_m_n.GetLength(I0) && Gemm1N == c1_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && + Gemm1N % Gemm1NPerBlock == 0)) + { + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(NPerBlock % Gemm1KPerBlock == 0)) + { + return false; + } + + const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c1_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const C1GridDesc_M_N& c1_grid_desc_m_n) + { + const auto M = c1_grid_desc_m_n.GetLength(I0); + const auto N = c1_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / Gemm1NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c1_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const C1GridDesc_M_N& c1_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c1_grid_desc_m_n); + } + + __device__ static auto GetGemm0WaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); + constexpr auto WaveSize = MfmaSelector::selected_mfma.wave_size; + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id) + { + constexpr auto WaveSize = MfmaSelector::selected_mfma.wave_size; + constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + static constexpr auto MakeD0sGridPointer() + { + return generate_tuple( + [&](auto i) { + using D0DataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + // D0 desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n) + { + const auto M = d0_grid_desc_m_n.GetLength(I0); + const auto N = d0_grid_desc_m_n.GetLength(I1); + + constexpr auto mfma = MfmaSelector::selected_mfma; + constexpr auto N3 = mfma.num_groups_per_blk; + constexpr auto N4 = mfma.num_input_blks; + constexpr auto N5 = mfma.group_size; + return transform_tensor_descriptor( + d0_grid_desc_m_n, + make_tuple(make_unmerge_transform( + make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)), + make_unmerge_transform( + make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{})); + } + + // D0s desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + using D0sGridPointer = decltype(MakeD0sGridPointer()); + using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t; + + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); + + static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for reduction + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c_block_space_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + D0sGridPointer p_d0s_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const C0DEElementwiseOperation& c0de_element_op, + const B1ElementwiseOperation& b1_element_op, + const C1DEElementwiseOperation& c1de_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const Block2CTileMap& block_2_ctile_map, + const C0MatrixMask& c0_matrix_mask) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + const auto d0s_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_d0s_grid[i], + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t gemm1_n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // + // set up Gemm0 + // + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), // will loop over GemmN dimension + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // Fused Gemm+Gemm pipeline + // for n in N0: + // for k in K0: + // acc[m][n] += A[m][k] * B0[k][n] + // acc1[m][o] += acc[m][n] * B1[n][o] + + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + true>{}; // TransposeC + + auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + const auto a_block_reset_copy_step = + make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0); + const auto b_block_reset_copy_step = + make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0); + + // gridwise GEMM pipeline + // Only supports LoopScheduler::Default + const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // + // set up Gemm1 + // + + // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type + constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + + constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); + + // d0 matrix threadwise copy + constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = + make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId + I1, // NBlockID + I1, // MRepeat + I1, // NRepeat + I1, // MWaveId + I1, // NWaveId + I1, // MPerXdl + I1, // NGroupNum + I1, // NInputNum + n4)); // registerNum + + auto d0s_thread_buf = generate_tuple( + [&](auto i) { + using D0DataType = remove_cvref_t>; + return StaticBuffer< + AddressSpaceEnum::Vgpr, + D0DataType, + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), + true>{}; + }, + Number{}); + + const auto wave_id = GetGemm0WaveIdx(); + const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 + + constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, n2, n4)); + + auto d0s_threadwise_copy = generate_tuple( + [&](auto i) { + using D0DataType = remove_cvref_t>; + return ThreadwiseTensorSliceTransfer_v2< + D0DataType, + D0DataType, + decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), + decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + n4, + 1, + false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(block_work_idx[I0], // MBlockId + 0, // NBlockId + 0, // mrepeat + 0, // nrepeat + wave_id[I0], // MWaveId + wave_id[I1], // NWaveId + wave_m_n_id[I1], // MPerXdl + 0, // group + wave_m_n_id[I0], // NInputIndex + 0)); // register number + }, + Number{}); + // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1 + // n0_n1_n2_n3 -> k0 + // m0_m1_m2 -> m + // n4 -> k1 + // NOTE: had to use merge_v3 or will spit out compilation errors + constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor( + acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), + make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)), + make_pass_through_transform(n4)), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // A1 matrix in AccVGPR + // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size + constexpr auto AccN3 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6); + + constexpr auto A1ThreadSlice_K0_M_K1 = + make_tuple(Number{}, Number{}, Number{}); + + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( + A1ThreadSlice_K0_M_K1, + make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<1, 0, 2>, + 2, + n4>{tensor_operation::element_wise::PassThrough{}}; + + // B1 matrix blockwise copy + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b1_grid_desc_bk0_n_bk1), + decltype(b1_block_desc_bk0_n_bk1), + B1BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B1BlockTransferSrcVectorDim, + 2, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + 1, + 1, + B1ThreadTransferSrcResetCoordinateAfterRun, + true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc_bk0_n_bk1, + make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_k0_m_k1.GetElementSpaceSize()); + + // reuse LDS space for gemm0's b_block_buf + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. + // therefore we may just as well assign Gemm1KPack = group_size + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; + + auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a1_thread_desc_k0_m_k1), + decltype(b1_block_desc_bk0_n_bk1), + decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)), + decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)), + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + Gemm1KPack, + true, // TransposeC + Gemm1KPack, // AMmaKStride + Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin + + auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); + + // + // Blockwise softmax + // + auto workspace_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::reduction_space_offset, + SharedMemTrait::reduction_space_size_aligned); + + // get acc0 8D thread cluster + constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() / + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0); + constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1); + constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2); + constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3); + constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4); + constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5); + constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6); + constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7); + + // get acc0 thread map + constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)), + make_pass_through_transform(I1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + const auto threadid_to_m_n_thread_cluster_adaptor = + chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor); + + // get acc0 2D thread cluster & 2D thread slice + constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4)); + constexpr auto thread_slice_desc_m_n = + make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4)); + + auto blockwise_softmax = BlockwiseSoftmax{}; + + const index_t num_gemm1_k_block_outer_loop = + b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; + constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; + + // Initialize C + StaticBuffer + c_thread_buf; + c_thread_buf.Clear(); + + // Initialize running sum and max of exponentiating row vectors + using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; + SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new; + running_sum = 0; + running_sum_new = 0; + running_max = NumericLimits::Lowest(); + running_max_new = NumericLimits::Lowest(); + + // gemm1 K loop + index_t gemm1_k_block_outer_index = 0; + do + { + auto n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock); + if(c0_matrix_mask.IsTileSkippable( + m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock)) + { + continue; + } + // gemm0 + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + acc_thread_buf, + num_k_block_main_loop); + // multiple d + if constexpr(NumD0Tensor) + { + static_for<0, MXdlPerWave, 1>{}([&](auto mr) { + static_for<0, NXdlPerWave, 1>{}([&](auto nr) { + static_for<0, n2, 1>{}([&](auto groupid) { + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + d0s_grid_buf[i], + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + + static_for<0, n4, 1>{}([&](auto i) { + constexpr index_t c_offset = acc0_thread_desc.CalculateOffset( + make_tuple(mr, nr, groupid, i)); + + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + return d0s_thread_buf[iSrc][i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { + return acc_thread_buf(Number{}); + }, + Number<2>{}); + + unpack2(c0de_element_op, dst_data_refs, src_data_refs); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0)); + }); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0)); + }); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 0, 1, -NXdlPerWave, 0, 0, 0, 0, 0, 0)); + }); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 1, -MXdlPerWave, 0, 0, 0, 0, 0, 0, 0)); + }); + } + else + { + static_for<0, acc_thread_buf.Size(), 1>{}( + [&](auto i) { c0de_element_op(acc_thread_buf(i), acc_thread_buf[i]); }); + } + + // do MNK padding or upper triangular masking + if constexpr(MaskOutUpperTriangle || PadN) + { + // 8d thread_desc in thread scope + constexpr auto c_thread_lengths = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + + // 8d block_desc in block scope + constexpr auto c_block_lengths = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + + constexpr auto M0 = c_block_lengths[I0]; + constexpr auto N0 = c_block_lengths[I1]; + constexpr auto M1 = c_block_lengths[I2]; + constexpr auto N1 = c_block_lengths[I3]; + constexpr auto M2 = c_block_lengths[I4]; + constexpr auto N2 = c_block_lengths[I5]; + constexpr auto N3 = c_block_lengths[I6]; + constexpr auto N4 = c_block_lengths[I7]; + + // works like multi-dimension static_for (static_ford), but provides both the linear + // index as well as n-d index + using Acc0TileIterator = SpaceFillingCurve< + decltype(c_thread_lengths), + typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, + typename uniform_sequence_gen::type, + false>; // SnakeCurved + + auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D( + Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{}); + + constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)), + make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); + + static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { + auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; + auto m_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; + auto n_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; + auto m_global = m_local + m_block_data_idx_on_grid; + auto n_global = n_local + n_block_data_idx_on_grid; + if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) + { + acc_thread_buf(i) = -ck::NumericLimits::Infinity(); + } + }); + } + + block_sync_lds(); // wait for lds read in gemm0 blockwise gemm + + // softmax + SoftmaxBuf& max = blockwise_softmax.max_value_buf; + SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; + + blockwise_softmax.Run(acc_thread_buf, workspace_buf); + + // TODO: may convert to log domain + running_max_new = mathext::max(max, running_max); + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + + mathext::exp(max - running_max_new) * sum; + + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. + + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for reduction LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + + // main body + if constexpr(num_gemm1_k_block_inner_loop > 1) + { + static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { + a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, + make_tuple(Number{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc_thread_desc_k0_m_k1, + make_tuple( + Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + // workaround compiler issue; see ck/ck.hpp + if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 && + is_same_v && MPerBlock == 256 && NPerBlock == 128 && + Gemm1NPerBlock == 128) + { + __builtin_amdgcn_sched_barrier(0); + } + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V + FloatGemmAcc c = c_thread_buf[I]; // O + FloatGemmAcc c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; // Formula by Dao et al., + // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 + + c_thread_buf(I) = c_new; // O_new + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, + a_block_reset_copy_step); // rewind K + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, + b_block_reset_copy_step); // rewind K and step N + + // update before next j iteration + running_max = running_max_new; + running_sum = running_sum_new; + + block_sync_lds(); // wait for gemm1 LDS read + } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + C1DEElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c1de_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index ad2bfe655b..309f7ca039 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -91,6 +91,7 @@ using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; using FastGelu = ck::tensor_operation::element_wise::FastGelu; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; template using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp new file mode 100644 index 0000000000..0aa7a5aa3d --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances); + +void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances); + +void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances); + +void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + Acc0BiasDataType, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpec>> +{ + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + Acc0BiasDataType, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpec>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + Acc0BiasDataType::Size() == 1 && + is_same_v, half_t>) + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + else if(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + Acc0BiasDataType::Size() == 1 && + is_same_v, BF16>) + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + else if(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 76121ffc34..eba248e599 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,5 +1,7 @@ add_instance_library(device_batched_gemm_softmax_gemm_permute_instance device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000..f73e3dea84 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +template +using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances = + std::tuple< + // clang-format off + // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| + // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + // clang-format on + >; + +void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskOutUpperTriangle>{}); +} + +void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskDisabled>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000..00b37d52b4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmPadded = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; + +static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +template +using device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = + std::tuple< + // clang-format off + // #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec| + // #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + // #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + // #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>, + // Padded fallback kernel + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>, + DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, F32, F16, PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec> + // clang-format on + >; + +void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskOutUpperTriangle>{}); +} + +void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances( + std::vector< + std::unique_ptr, + ck::Tuple<>, + PassThrough, + PassThrough, + ScaleAdd, + PassThrough, + PassThrough, + MaskingSpecialization::MaskDisabled>>>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances< + 2, + 1, + 1, + 1, + 1, + MaskingSpecialization::MaskDisabled>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp new file mode 100644 index 0000000000..799dccc0ff --- /dev/null +++ b/profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp @@ -0,0 +1,395 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int O, + int G0, + int G1, + float alpha = -1.f) + +{ + + using PassThrough = tensor_operation::element_wise::PassThrough; + using ScaleAdd = tensor_operation::element_wise::ScaleAdd; + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using C0DEElementOp = ScaleAdd; + using Acc0ElementOp = PassThrough; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + using AccDataType = float; + using D0DataType = tuple_element_t<0, Acc0BiasesDataType>; + using tensor_operation::device::MaskingSpecialization; + + // Ref Gemm0: various type in, fp32 out + using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm; + + // Ref Softmax: fp32 in, various type out + using ReferenceSoftmaxInstance = + tensor_operation::host::ReferenceSoftmax; + + // Ref Gemm1: various type in, various type out + using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm; + + bool pass = true; + + // A layout [G0, M, G1, K] + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}; + + // B0 layout [G0, N, G1, K] + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1}; + + // B1 layout [G0, N, G1, O] + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O}; + + // C layout [G0, M, G1, O] + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; + + // D layout [G0, M, G1, N] + std::vector d0_gs_ms_ns_lengths{G0, G1, M, N}; + std::vector d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; + + const int BatchCount = G0 * G1; + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor d0_gs_ms_ns(d0_gs_ms_ns_lengths, d0_gs_ms_ns_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + std::srand(1); // work around test flakiness + switch(init_method) + { + case 0: break; + case 1: + // Still unsure whether this kind of deterministic floating point accurary issue is expected + // or not. May want to try exact same approach as the GPU kernel in the host reference + // GEMM+Softmax+GEMM function to see if the accuracy discrepancy goes away. Until then, + // shrink the input value range as it is less likely to produce errors of around ~1e-3. + // a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + // b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{1}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + d0_gs_ms_ns.GenerateTensorValue(GeneratorTensor_1{1}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d0_device_buf(sizeof(D0DataType) * d0_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + d0_device_buf.ToDevice(d0_gs_ms_ns.mData.data()); + + if(alpha < 0) + { + alpha = 1.f / std::sqrt(K); // usually 1 / sqrt(head_dim) + } + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto c0de_element_op = C0DEElementOp{alpha}; + auto acc0_element_op = Acc0ElementOp{}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + using DeviceOp = + tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + Acc0BiasesDataType, + ck::Tuple<>, + AElementOp, + B0ElementOp, + C0DEElementOp, + B1ElementOp, + CElementOp, + MaskingSpec>; + + // get device op instances + const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, M, K}); + Tensor b0_g_k_n({BatchCount, K, N}); + Tensor b1_g_n_o({BatchCount, N, O}); + Tensor acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 + Tensor d0_g_m_n({BatchCount, M, N}); + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + d0_gs_ms_ns.ForEach([&](auto& self, auto idx) { + d0_g_m_n(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + acc0_g_m_n.ForEach([&](auto&, auto idx) { + c0de_element_op(acc0_g_m_n(idx), acc0_g_m_n(idx), d0_g_m_n(idx)); + }); + // mask out upper triangle + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle && idx[1] < idx[2]) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * G1 + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + std::array{ + d0_device_buf.GetDeviceBuffer()}, // std::array p_acc0_biases; + {}, // std::array p_acc1_biases; + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + std::array, 1>{ + d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths + std::array, 1>{ + d0_gs_ms_ns_strides}, // std::array, + // 1>{acc0_biases_gs_ms_ns_strides}, + {}, // std::array, 1>{acc1_biases_gs_ms_os_lengths}, + {}, // std::array, 1>{acc1_biases_gs_ms_os_strides}, + a_element_op, + b0_element_op, + c0de_element_op, + b1_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O + + sizeof(D0DataType) * M * N) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && + std::is_same_v && + std::is_same_v && + std::is_same_v && + std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + pass = pass & ck::utils::check_err(c_gs_ms_os_device_result, + c_gs_ms_os_host_result, + "Error: Incorrect results!", + rtol, + atol); + + if(do_log) + { + LogRangeAsType(std::cout << "a_gs_ms_ks: ", a_gs_ms_ks.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b0_gs_ns_ks : ", b0_gs_ns_ks.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b1_gs_os_ns : ", b1_gs_os_ns.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_gs_ms_os_device_result : ", + c_gs_ms_os_device_result.mData, + ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt index f858d9f208..79af2b0d3a 100644 --- a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -5,4 +5,11 @@ add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_ge target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) -add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) \ No newline at end of file +add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) + +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) +target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) +target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) +add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) +add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) \ No newline at end of file diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp new file mode 100644 index 0000000000..fe65a6fb96 --- /dev/null +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "test_batched_gemm_bias_softmax_gemm_permute_util.hpp" + +template +class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16 + : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute +{ +}; + +using I1_t = ck::Number<1>; +using I2_t = ck::Number<2>; + +using MaskDisabled_t = + ck::integral_constant; +using MaskOutUpperTriangle_t = + ck::integral_constant; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, ck::Tuple<>, MaskDisabled_t>, + std::tuple, ck::Tuple<>, MaskOutUpperTriangle_t> + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, KernelTypes); + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Test_BF16) { this->Run(); } + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadM) +{ + this->lengths_ = std::vector>{ + {136, 128, 32, 128, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadN) +{ + this->lengths_ = std::vector>{ + {128, 136, 32, 128, 3, 2}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadK) +{ + this->lengths_ = std::vector>{ + {128, 128, 40, 128, 2, 4}, + {128, 128, 136, 128, 4, 2}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 136, 1, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddM) +{ + this->lengths_ = std::vector>{ + {129, 128, 32, 128, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddN) +{ + this->lengths_ = std::vector>{ + {128, 129, 32, 128, 4, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddK) +{ + this->lengths_ = std::vector>{ + {128, 128, 33, 128, 2, 3}, + {128, 128, 129, 128, 2, 3}, + }; + this->Run(); +} + +// If kernel B1Layout is RowMajor, expect not to support odd O size +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 129, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16_IrregularK) +{ + this->lengths_ = std::vector>{{256, 256, 160, 160, 1, 16}, + {256, 64, 160, 64, 1, 16}, + {1024, 1024, 80, 80, 1, 16}, + {1024, 64, 80, 64, 1, 16}, + {4096, 4096, 40, 40, 1, 16}, + {4096, 64, 40, 64, 1, 16}}; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16) +{ + this->lengths_ = std::vector>{ + {256, 256, 64, 64, 48, 16}, + {256, 256, 128, 128, 48, 16}, + {512, 512, 64, 64, 48, 16}, + {512, 512, 128, 128, 48, 16}, + {1024, 1024, 64, 64, 48, 16}, + {1024, 1024, 128, 128, 48, 16}, + {2048, 2048, 64, 64, 48, 16}, + {2048, 2048, 128, 128, 48, 16}, + {4096, 4096, 64, 64, 48, 16}, + {4096, 4096, 128, 128, 48, 16}, + }; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +using ck::tensor_operation::device::GemmSpecialization; + +TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch) +{ + int P = 120; // requires padding + int Q = 128; // do not require padding + + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(Q, P, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(P, P, P, P)); + // clang-format on +} + +TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch) +{ + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); + // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); + // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129)); + // clang-format on +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, AdhocTest) +{ + this->lengths_ = std::vector>{ + {49, 49, 64, 64, 4, 6}, + {64, 49, 64, 64, 4, 6}, + {1020, 1020, 64, 128, 4, 6}, + {576, 576, 64, 64, 4, 6}, + }; + this->Run(); +} diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp new file mode 100644 index 0000000000..7235cd1b0b --- /dev/null +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp @@ -0,0 +1,182 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "test_batched_gemm_softmax_gemm_permute_util.hpp" + +template +class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16 + : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute +{ +}; + +using I1_t = ck::Number<1>; +using I2_t = ck::Number<2>; + +using MaskDisabled_t = + ck::integral_constant; +using MaskOutUpperTriangle_t = + ck::integral_constant; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple, ck::Tuple<>, MaskDisabled_t>, + std::tuple, ck::Tuple<>, MaskOutUpperTriangle_t> + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, KernelTypes); + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16) { this->Run(); } + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadM) +{ + this->lengths_ = std::vector>{ + {136, 128, 32, 128, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadN) +{ + this->lengths_ = std::vector>{ + {128, 136, 32, 128, 3, 2}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadK) +{ + this->lengths_ = std::vector>{ + {128, 128, 40, 128, 2, 4}, + {128, 128, 136, 128, 4, 2}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_PadO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 136, 1, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddM) +{ + this->lengths_ = std::vector>{ + {129, 128, 32, 128, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddN) +{ + this->lengths_ = std::vector>{ + {128, 129, 32, 128, 4, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddK) +{ + this->lengths_ = std::vector>{ + {128, 128, 33, 128, 2, 3}, + {128, 128, 129, 128, 2, 3}, + }; + this->Run(); +} + +// If kernel B1Layout is RowMajor, expect not to support odd O size +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO) +{ + this->lengths_ = std::vector>{ + {128, 128, 32, 129, 2, 3}, + }; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16_IrregularK) +{ + this->lengths_ = std::vector>{{256, 256, 160, 160, 1, 16}, + {256, 64, 160, 64, 1, 16}, + {1024, 1024, 80, 80, 1, 16}, + {1024, 64, 80, 64, 1, 16}, + {4096, 4096, 40, 40, 1, 16}, + {4096, 64, 40, 64, 1, 16}}; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16) +{ + this->lengths_ = std::vector>{ + {256, 256, 64, 64, 48, 16}, + {256, 256, 128, 128, 48, 16}, + {512, 512, 64, 64, 48, 16}, + {512, 512, 128, 128, 48, 16}, + {1024, 1024, 64, 64, 48, 16}, + {1024, 1024, 128, 128, 48, 16}, + {2048, 2048, 64, 64, 48, 16}, + {2048, 2048, 128, 128, 48, 16}, + {4096, 4096, 64, 64, 48, 16}, + {4096, 4096, 128, 128, 48, 16}, + }; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} + +using ck::tensor_operation::device::GemmSpecialization; + +TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch) +{ + int P = 120; // requires padding + int Q = 128; // do not require padding + + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, Q)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, Q, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, Q, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(Q, P, P, P)); + EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(P, P, P, P)); + // clang-format on +} + +TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMismatch) +{ + // IsSupported(M, N, K, O) + // clang-format off + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 120, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 120)); + // Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 129, 128)); + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 130, 128)); + // Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0 + EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128{}.IsSupported(128, 128, 128, 129)); + // clang-format on +} + +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest) +{ + this->lengths_ = std::vector>{ + {49, 49, 64, 64, 4, 6}, + {64, 49, 64, 64, 4, 6}, + {1020, 1020, 64, 128, 4, 6}, + {576, 576, 64, 64, 4, 6}, + }; + this->Run(); +} diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp new file mode 100644 index 0000000000..af5f0efec3 --- /dev/null +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp @@ -0,0 +1,380 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp" +#include "profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp" + +using ck::tensor_operation::device::GemmSpecialization; +using ck::tensor_operation::device::MaskingSpecialization; +using ck::tensor_operation::device::TensorSpecialization; + +template +using I = ck::Number; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test +{ + using NumDimGType = std::tuple_element_t<0, Tuple>; + using NumDimMType = std::tuple_element_t<1, Tuple>; + using NumDimNType = std::tuple_element_t<2, Tuple>; + using NumDimKType = std::tuple_element_t<3, Tuple>; + using NumDimOType = std::tuple_element_t<4, Tuple>; + using ADataType = std::tuple_element_t<5, Tuple>; + using B0DataType = std::tuple_element_t<6, Tuple>; + using B1DataType = std::tuple_element_t<7, Tuple>; + using CDataType = std::tuple_element_t<8, Tuple>; + using Acc0BiasDataType = std::tuple_element_t<9, Tuple>; + using Acc1BiasDataType = std::tuple_element_t<10, Tuple>; + using MaskingType = std::tuple_element_t<11, Tuple>; + + std::vector> lengths_ = { + {256, 256, 64, 64, 6, 4}, + {256, 256, 128, 128, 4, 6}, + {512, 512, 64, 64, 3, 2}, + {512, 512, 128, 128, 2, 3}, + {1024, 1024, 64, 64, 3, 1}, + {1024, 1024, 128, 128, 1, 1}, + }; + bool bench_ = false; + bool verify_ = true; + + void RunSingle(int M, int N, int K, int O, int G0, int G1) + { + bool pass = + ck::profiler::profile_batched_gemm_bias_softmax_gemm_permute_impl( + verify_, 2, false, bench_, M, N, K, O, G0, G1); + + EXPECT_TRUE(pass); + } + + void Run() + { + for(auto lengths : this->lengths_) + { + int M = lengths[0]; + int N = lengths[1]; + int K = lengths[2]; + int O = lengths[3]; + int G0 = lengths[4]; + int G1 = lengths[5]; + + this->RunSingle(M, N, K, O, G0, G1); + } + } +}; + +template +struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128 +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + + template + using S = ck::Sequence; + + using ADataType = F16; + using B0DataType = F16; + using B1DataType = F16; + using AccDataType = float; + using CShuffleDataType = F16; + using CDataType = F16; + + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = ScaleAdd; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + + // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value; + + using DeviceGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< + 2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple, + ck::Tuple<>, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecialization::Default, // ATensorSpec + TensorSpecialization::Default, // B0TensorSpec + TensorSpecialization::Default, // B1TensorSpec + TensorSpecialization::Default, // CTensorSpec + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle + + bool IsSupported(int M, int N, int K, int O) + { + const int G0 = 1, G1 = 1; + + // A layout [G0, M, G1, K] + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}; + + // B0 layout [G0, N, G1, K] + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1}; + + // B1 layout [G0, N, G1, O] + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O}; + + // C layout [G0, M, G1, O] + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; + + // D layout [G0, M, G1, N] + std::vector d0_gs_ms_ns_lengths{G0, G1, M, N}; + std::vector d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; + + auto gemm = DeviceGemmGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + std::array{nullptr}, // p_acc0_biases + {}, // p_acc1_biases + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + std::array, 1>{ + d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths + std::array, 1>{ + d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}, // acc1_biases_gs_ms_os_strides + PassThrough{}, // a_element_op + PassThrough{}, // b0_element_op + Acc0ElementOp{1.f}, // acc0_element_op + PassThrough{}, // b1_element_op + PassThrough{}); // c_element_op + + return gemm.IsSupportedArgument(argument); + } +}; + +template +struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128 +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; + + template + using S = ck::Sequence; + + using ADataType = BF16; + using B0DataType = BF16; + using B1DataType = BF16; + using AccDataType = float; + using CShuffleDataType = BF16; + using CDataType = BF16; + + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = ScaleAdd; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + + // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value; + + using DeviceGemmGemmInstance = + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle< + 2, + 1, + 1, + 1, + 1, + ADataType, + B0DataType, + B1DataType, + CDataType, + ck::Tuple, + ck::Tuple<>, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmSpec, + TensorSpecialization::Default, // ATensorSpec + TensorSpecialization::Default, // B0TensorSpec + TensorSpecialization::Default, // B1TensorSpec + TensorSpecialization::Default, // CTensorSpec + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 128, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 4, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<8, 32, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle + + bool IsSupported(int M, int N, int K, int O) + { + const int G0 = 1, G1 = 1; + + // A layout [G0, M, G1, K] + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1}; + + // B0 layout [G0, N, G1, K] + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1}; + + // B1 layout [G0, N, G1, O] + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O}; + + // C layout [G0, M, G1, O] + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1}; + + // D layout [G0, M, G1, N] + std::vector d0_gs_ms_ns_lengths{G0, G1, M, N}; + std::vector d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1}; + + auto gemm = DeviceGemmGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + static_cast(nullptr), + std::array{nullptr}, // p_acc0_biases + {}, // p_acc1_biases + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ns_ks_lengths, + b0_gs_ns_ks_strides, + b1_gs_os_ns_lengths, + b1_gs_os_ns_strides, + c_gs_ms_os_lengths, + c_gs_ms_os_strides, + std::array, 1>{ + d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths + std::array, 1>{ + d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides + {}, // acc1_biases_gs_ms_os_lengths + {}, // acc1_biases_gs_ms_os_strides + PassThrough{}, // a_element_op + PassThrough{}, // b0_element_op + Acc0ElementOp{1.f}, // acc0_element_op + PassThrough{}, // b1_element_op + PassThrough{}); // c_element_op + + return gemm.IsSupportedArgument(argument); + } +}; From b63accee2be9c067c967a0479550ad3d272ed5a2 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 8 Feb 2023 15:25:53 -0800 Subject: [PATCH 07/22] adding the first draft of changelog (#571) * adding the first draft of changelog * second draft of changelog --- CHANGELOG.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..2c3215ae44 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,23 @@ +# Change Log for Composable Kernel + +Full documentation for Composable Kernel is not yet available. + +## CK 0.1.1 for ROCm 5.5.0 + +### Fixed +- Fixed a bug in 6-dimensional kernels (#555). +- Fixed grouped ConvBwdWeight test case failure (#524). + +### Optimizations +- Optimized ... + +### Added +- Added user tutorial (#563). +- Added more instances for irregular GEMM sizes (#560). +- Added inter-wave consumer-producer programming model for GEMM kernels (#310). +- Added multi-D GEMM client APIs (#534). +- Added multi-embeddings support (#542). +- Added Navi3x blockwise GEMM and real GEMM support (#541). + +### Changed +- Changed ... From 76d144fa7c396e52631719f79008e7099b6cd30d Mon Sep 17 00:00:00 2001 From: guangzlu <87220526+guangzlu@users.noreply.github.com> Date: Fri, 10 Feb 2023 01:37:29 +0800 Subject: [PATCH 08/22] Add instance for elementwise normlization (#573) * added instances for large N * add instance for elementwise normlization * added supported restrict in device_elementwise_normalization_impl.hpp --- .../device/impl/device_elementwise_normalization_impl.hpp | 5 +++++ .../device_elementwise_normalization_f16_instance.cpp | 5 +++++ .../test_elementwise_layernorm_fp16.cpp | 2 +- 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp index 1085bdf922..1fa69288a4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp @@ -533,6 +533,11 @@ struct DeviceElementwiseNormalizationImpl return (false); } + if(p_arg_->x_lds_size_ >= 65536) + { + return (false); + } + return true; }; diff --git a/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp index 7f15372ed9..b160d4fe1a 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/elementwise_normalization/device_elementwise_normalization_f16_instance.cpp @@ -23,6 +23,11 @@ template + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel for large N + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel for large N + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>, // fallback kernel for large N + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>, // fallback kernel for large N + DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 4, 64, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel for large N DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel DeviceElementwiseNormalizationImpl, F16, F16, F32, F16, XElementwise ,YElementwise, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>, // fallback kernel diff --git a/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp b/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp index 403881b3cc..e80995c4f0 100644 --- a/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp +++ b/test/elementwise_normalization/test_elementwise_layernorm_fp16.cpp @@ -23,7 +23,7 @@ class TestElementwiseLayernorm : public ::testing::Test { // M, N std::vector> lengths = { - {1, 1}, {25, 16}, {39, 777}, {100, 200}, {1024, 1024}, {48 * 256, 2048}}; + {1, 1}, {25, 16}, {39, 777}, {100, 200}, {1024, 1024}, {48 * 256, 2048}, {4096, 8192}}; for(auto length : lengths) { From f7d28f3e4b2992ff58169600e914bf0cc72bd756 Mon Sep 17 00:00:00 2001 From: rocking5566 Date: Fri, 10 Feb 2023 05:02:55 +0800 Subject: [PATCH 09/22] Gemm+layernorm instance, ckProfiler, client example (#568) * Add gemm + layernorm instance * Add ckProfiler * Add test * Add client example * Detect if user forger to set the workrspace * Use literal in the example * [What] use builtin function for sqrt [Why] compiler will not use v_sqrt_f64_e64 if we use ::sqrt() * check gemm vaildity in IsSupportedArgument * Add more testcases * Merge duplicated folder in client example * Print more infomation * Use better kernel parameter for MS problem size * clang format * Add constexpr for if condition and remove redundant include * Remove cstdlib and add constexpr --- client_example/01_gemm/gemm.cpp | 2 +- .../gemm_add_add_fastgelu.cpp | 2 +- .../gemm_add_fastgelu.cpp | 2 +- .../gemm_fastgelu.cpp | 2 +- .../03_gemm_layernorm/CMakeLists.txt | 7 +- ...m.cpp => gemm_add_add_layernorm_naive.cpp} | 2 +- .../gemm_add_relu_add_layernorm_welford.cpp | 244 ++++++++++++ .../gemm_add_multiply.cpp | 2 +- ...bias_relu_add_layernorm_xdl_naive_fp16.cpp | 3 +- ...as_relu_add_layernorm_xdl_welford_fp16.cpp | 19 +- .../gemm_layernorm_xdl_naive_fp16.cpp | 3 +- ...xdl_layernorm_naive_single_kernel_fp16.cpp | 2 +- ...gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 16 +- include/ck/utility/math_v2.hpp | 4 +- .../device_operation_instance_factory.hpp | 1 + .../gpu/gemm_add_add_fastgelu.hpp | 1 - .../gpu/gemm_add_relu_add_layernorm.hpp | 172 +++++++++ .../CMakeLists.txt | 6 + ..._layernorm_f16_km_kn_mn_mn_mn_instance.cpp | 130 +++++++ ..._layernorm_f16_km_nk_mn_mn_mn_instance.cpp | 130 +++++++ ..._layernorm_f16_mk_kn_mn_mn_mn_instance.cpp | 130 +++++++ ..._layernorm_f16_mk_nk_mn_mn_mn_instance.cpp | 127 +++++++ ...ofile_gemm_add_relu_add_layernorm_impl.hpp | 346 ++++++++++++++++++ profiler/src/CMakeLists.txt | 3 +- .../profile_gemm_add_relu_add_layernorm.cpp | 215 +++++++++++ test/CMakeLists.txt | 3 +- test/gemm_layernorm/CMakeLists.txt | 7 + .../test_gemm_add_relu_add_layernorm_fp16.cpp | 77 ++++ test/normalization/CMakeLists.txt | 13 +- 29 files changed, 1635 insertions(+), 36 deletions(-) rename client_example/03_gemm_layernorm/{gemm_add_add_layernorm.cpp => gemm_add_add_layernorm_naive.cpp} (99%) create mode 100644 client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp create mode 100644 profiler/src/profile_gemm_add_relu_add_layernorm.cpp create mode 100644 test/gemm_layernorm/CMakeLists.txt create mode 100644 test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp index a8a6bf16c2..ba7118ba39 100644 --- a/client_example/01_gemm/gemm.cpp +++ b/client_example/01_gemm/gemm.cpp @@ -83,7 +83,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp index f88e72b62e..08f297f58a 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -92,7 +92,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp index 512555f978..658c1e9e8f 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -88,7 +88,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp index 7237231032..ea269545a5 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -84,7 +84,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt index 3742e70844..b38698d906 100644 --- a/client_example/03_gemm_layernorm/CMakeLists.txt +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -1,2 +1,5 @@ -add_executable(client_gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp) -target_link_libraries(client_gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations) +add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) +target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_operations) + +add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) +target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_operations) diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp similarity index 99% rename from client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp rename to client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp index 02da5ff6ce..caa6573788 100644 --- a/client_example/03_gemm_layernorm/gemm_add_add_layernorm.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp @@ -190,7 +190,7 @@ int main() [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp new file mode 100644 index 0000000000..d4f0c2048b --- /dev/null +++ b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using F16 = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// DataType +using ADataType = F16; +using BDataType = F16; +using D0DataType = F16; +using D1DataType = F16; +using GammaDataType = F16; +using BetaDataType = F16; +using HDataType = F16; + +// Layout +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Row; +using HLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddReluAdd; +using HElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{}, mMemSize_(mem_size) + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + void SetZero() const { (void)hipMemset(p_mem_, 0, mMemSize_); } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; + std::size_t mMemSize_; +}; + +int main(int argc, char* argv[]) +{ + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideD0 = 0; + ck::index_t StrideD1 = N; + ck::index_t StrideH = N; + + float epsilon = 1e-5; + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{})); + SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{})); + SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * + f_matrix_space_size(M, N, StrideD0, D0Layout{})); + SimpleDeviceMem d1_device_buf(sizeof(D1DataType) * + f_matrix_space_size(M, N, StrideD1, D1Layout{})); + SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N); + SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N); + SimpleDeviceMem h_device_buf(sizeof(HDataType) * f_matrix_space_size(M, N, StrideH, HLayout{})); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm< + ALayout, + BLayout, + ck::Tuple, + HLayout, + ADataType, + BDataType, + ck::Tuple, + GammaDataType, + BetaDataType, + HDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddReluAdd, + ck::tensor_operation::element_wise::PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + const auto h_element_op = HElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + h_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {StrideD0, StrideD1}, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + h_device_buf.SetZero(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + (sizeof(D0DataType) + sizeof(D1DataType) + sizeof(HDataType)) * M * N + + (sizeof(GammaDataType) + sizeof(BetaDataType)) * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, " + << op_name << std::endl; + + if(ave_time < best_ave_time) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + h_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {StrideD0, StrideD1}, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + SimpleDeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + h_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + + return 0; +} \ No newline at end of file diff --git a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp index 740f315b8c..28524a9eee 100644 --- a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp +++ b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp @@ -92,7 +92,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp index 83b17699a7..192fe87b62 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -116,7 +115,7 @@ auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; - if(std::is_same::value) + if constexpr(std::is_same::value) { return HostTensorDescriptor({row, col}, {stride, 1_uz}); } diff --git a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp index b927ae2828..3f01e69477 100644 --- a/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -15,6 +14,7 @@ #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" #include "ck/library/utility/check_err.hpp" @@ -69,21 +69,20 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayern // clang-format on auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { - return HostTensorDescriptor(std::vector({len}), - std::vector({stride})); + return HostTensorDescriptor({len}, {stride}); }; auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if(std::is_same::value) + using namespace ck::literals; + + if constexpr(std::is_same::value) { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({stride, 1})); + return HostTensorDescriptor({row, col}, {stride, 1_uz}); } else { - return HostTensorDescriptor(std::vector({row, col}), - std::vector({1, stride})); + return HostTensorDescriptor({row, col}, {1_uz, stride}); } }; @@ -97,6 +96,7 @@ void host_gemm_layernorm(Tensor& h_m_n, AElementOp a_element_op, BElementOp b_element_op, CDEElementOp cde_element_op, + HElementOp h_element_op, int M, int N, AccDataType epsilon = 1e-5) @@ -145,7 +145,7 @@ void host_gemm_layernorm(Tensor& h_m_n, auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); auto ref_layernorm_argument = ref_layernorm.MakeArgument( - e_m_n, gamma_n, beta_n, h_m_n, HElementOp{}, {M, N}, {1}, epsilon); + e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon); ref_layernorm_invoker.Run(ref_layernorm_argument); } @@ -249,6 +249,7 @@ int main() a_element_op, b_element_op, cde_element_op, + h_element_op, M, N, epsilon); diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp index bb4b60cbfe..4da6da65f7 100644 --- a/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_naive_fp16.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -115,7 +114,7 @@ auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; - if(std::is_same::value) + if constexpr(std::is_same::value) { return HostTensorDescriptor({row, col}, {stride, 1_uz}); } diff --git a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp index 3c3e36be6a..e7d857c4a0 100644 --- a/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp +++ b/example/21_gemm_layernorm/gemm_xdl_layernorm_naive_single_kernel_fp16.cpp @@ -135,7 +135,7 @@ int main(int argc, char* argv[]) [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { using namespace ck::literals; - if(std::is_same::value) + if constexpr(std::is_same::value) { return HostTensorDescriptor({row, col}, {stride, 1_uz}); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp index 2f4bf3ee0e..b53927a9ed 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -669,6 +669,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle { throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting"); } + if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr || + arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr) + throw std::runtime_error("wrong! WorkSpace pointer has not been set"); index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_); @@ -939,7 +942,11 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle } } - return true; + return GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.gemm_e_grid_desc_m_n_, + arg.block_2_etile_map_); } // polymorphic @@ -1055,7 +1062,12 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle << GemmKPerBlock << ", " << AK1 << ", " << BK1 << ", " - << getGemmSpecializationString(GemmSpec) + << getGemmSpecializationString(GemmSpec) << ", " + << PostShuffleThreadClusterSize_M_N::At(I0) << ", " + << PostShuffleThreadClusterSize_M_N::At(I1) << ", " + << LayernormThreadClusterSize_M_N::At(I0) << ", " + << LayernormThreadClusterSize_M_N::At(I1) << ", " + << LayernormThreadSliceSize_M << ">" << " LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 7334703241..4aba0b1192 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -158,9 +158,9 @@ static inline __device__ bool isnan(half_t x) return (xx & 0x7FFF) > 0x7C00; }; -static inline __device__ float sqrt(float x) { return ::sqrtf(x); }; +static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; -static inline __device__ double sqrt(double x) { return ::sqrt(x); }; +static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; } // namespace math } // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 309f7ca039..6210637ad3 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -89,6 +89,7 @@ using Scale = ck::tensor_operation::element_wise::Scale; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; using FastGelu = ck::tensor_operation::element_wise::FastGelu; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp index 09d8e8b95b..90b6e11b9b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp @@ -10,7 +10,6 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp new file mode 100644 index 0000000000..7beae83cdc --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances( + std::vector>>&); + +void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances( + std::vector>>&); + +// GEMM + Add + Relu + Add + Layernorm +template +struct DeviceOperationInstanceFactory, + HLayout, + ADataType, + BDataType, + ck::Tuple, + GammaDataType, + BetaDataType, + HDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddReluAdd, + ck::tensor_operation::element_wise::PassThrough>> +{ + using DeviceOp = DeviceGemmMultipleDLayernorm, + HLayout, + ADataType, + BDataType, + ck::Tuple, + GammaDataType, + BetaDataType, + HDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddReluAdd, + ck::tensor_operation::element_wise::PassThrough>; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances( + op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt new file mode 100644 index 0000000000..97693a2566 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt @@ -0,0 +1,6 @@ +add_instance_library(device_gemm_add_relu_add_layernorm_instance + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..47b8d23424 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// h = layernorm(e, gamma, beta) +// outout: h[m, n] +// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] +template +using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances = std::tuple< + // clang-format off + //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| + //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | + //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +// irregular tile size +using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| + //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | + //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances< + LoopScheduler::Default, + PipelineVersion::v1>{}); +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances< + LoopScheduler::Interwave, + PipelineVersion::v1>{}); +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances< + LoopScheduler::Default, + PipelineVersion::v2>{}); +#endif + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..efa030ec49 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// h = layernorm(e, gamma, beta) +// outout: h[m, n] +// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] +template +using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances = std::tuple< + // clang-format off + //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| + //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | + //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +// irregular tile size +using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| + //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | + //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Col, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances< + LoopScheduler::Default, + PipelineVersion::v1>{}); +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances< + LoopScheduler::Interwave, + PipelineVersion::v1>{}); +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances< + LoopScheduler::Default, + PipelineVersion::v2>{}); +#endif + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..f2735020e6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// h = layernorm(e, gamma, beta) +// outout: h[m, n] +// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] +template +using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances = std::tuple< + // clang-format off + //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| + //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | + //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +// irregular tile size +using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| + //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | + //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Row, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances< + LoopScheduler::Default, + PipelineVersion::v1>{}); +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances< + LoopScheduler::Interwave, + PipelineVersion::v1>{}); +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances< + LoopScheduler::Default, + PipelineVersion::v2>{}); +#endif + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp new file mode 100644 index 0000000000..7d4aae928b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; +using F16_F16_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e = elementwise((a * b), d0, d1) +// h = layernorm(e, gamma, beta) +// outout: h[m, n] +// input: a[k, m], b[k, n], d0[m, n], d1[m, n], gamma[n], beta[n] +template +using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances = std::tuple< + // clang-format off + //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| + //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | + //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 8>, 8, S<32, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<32, 4>, 8, S<32, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 8>, 8, S<16, 8>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline>, + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<16, 4>, 8, S<16, 4>, 1, GemmLoopScheduler, GemmPipeline> + // clang-format on + >; + +// irregular tile size +using device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //#######################################| A| B| Ds| H| AData| BData| AccData| CShuffle| DsData| EMeanVarData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| LoopScheduler| Pipeline| + //#######################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ThreadClusterLengths| ScalarPerVector| ThreadClusterLengths| ThreadSliceSize| | | + //#######################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _NWaveNPerXdl| _M_N| _M| | | + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // pipeline v1, 1 wave + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v1> +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + // pipeline v1, 2 waves + , + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Interwave, PipelineVersion::v1> +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + // pipeline v2, 1 wave + , + DeviceGemmMultipleDLayernorm_Xdl_CShuffle< Row, Col, Row_Row_Tuple, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, F16, F16, F16, PassThrough, PassThrough, AddReluAdd, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<16, 4>, 1, S<16, 4>, 1, LoopScheduler::Default, PipelineVersion::v2> +#endif + // clang-format on + >; + +void add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances< + LoopScheduler::Default, + PipelineVersion::v1>{}); +#if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances< + LoopScheduler::Interwave, + PipelineVersion::v1>{}); +#endif +#if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances< + LoopScheduler::Default, + PipelineVersion::v2>{}); +#endif + add_device_operation_instances( + instances, + device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp new file mode 100644 index 0000000000..e1c90f0f52 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_add_relu_add_layernorm_impl.hpp @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_layernorm.hpp" + +namespace ck { +namespace profiler { + +template +void host_gemm_layernorm(Tensor& h_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& d0_m_n, + const Tensor& d1_m_n, + const Tensor& gamma_n, + const Tensor& beta_n, + AElementOp a_element_op, + BElementOp b_element_op, + CDEElementOp cde_element_op, + HElementOp h_element_op, + int M, + int N, + AccDataType epsilon = 1e-5) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using ReferenceGemm = ck::tensor_operation::host::ReferenceGemm; + + using ReferenceLayernorm = ck::tensor_operation::host::ReferenceLayernorm; + + Tensor e_m_n(HostTensorDescriptor{M, N}); + Tensor c_m_n(HostTensorDescriptor{M, N}); + + auto ref_gemm = ReferenceGemm{}; + auto ref_gemm_invoker = ref_gemm.MakeInvoker(); + + auto ref_gemm_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_gemm_invoker.Run(ref_gemm_argument); + + for(int n = 0; n < N; ++n) + { + for(int m = 0; m < M; ++m) + { + AccDataType e = static_cast(e_m_n(m, n)); + AccDataType d0 = static_cast(d0_m_n(m, n)); + AccDataType d1 = static_cast(d1_m_n(m, n)); + cde_element_op(e, c_m_n(m, n), d0, d1); + e_m_n(m, n) = static_cast(e); + } + } + + ReferenceLayernorm ref_layernorm; + auto ref_layernorm_invoker = ref_layernorm.MakeInvoker(); + + auto ref_layernorm_argument = ref_layernorm.MakeArgument( + e_m_n, gamma_n, beta_n, h_m_n, h_element_op, {M, N}, {1}, epsilon); + ref_layernorm_invoker.Run(ref_layernorm_argument); +} + +template +bool profile_gemm_add_relu_add_layernorm_impl(int do_verification, + int init_method, + bool /*do_log*/, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideD0, + int StrideD1, + int StrideH, + AccDataType epsilon = 1e-5) +{ + auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor({len}, {stride}); + }; + + auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if constexpr(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); + Tensor d0_m_n(f_host_tensor_descriptor2d(M, N, StrideD0, D0Layout{})); + Tensor d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{})); + Tensor gamma_n(f_host_tensor_descriptor1d(N, 1)); + Tensor beta_n(f_host_tensor_descriptor1d(N, 1)); + Tensor h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{})); + Tensor h_m_n_host(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{})); + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + gamma_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + beta_n.GenerateTensorValue(GeneratorTensor_3{-1, 1}); + break; + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CDEElementOp = AddReluAdd; + using HElementOp = PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + const auto h_element_op = HElementOp{}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm< + ALayout, + BLayout, + ck::Tuple, + HLayout, + ADataType, + BDataType, + ck::Tuple, + GammaDataType, + BetaDataType, + HDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddReluAdd, + ck::tensor_operation::element_wise::PassThrough>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + host_gemm_layernorm(h_m_n_host, + a_m_k, + b_k_n, + d0_m_n, + d1_m_n, + gamma_n, + beta_n, + a_element_op, + b_element_op, + cde_element_op, + h_element_op, + M, + N, + epsilon); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); + DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); + DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize()); + DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize()); + DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); + d1_m_n_device_buf.ToDevice(d1_m_n.mData.data()); + gamma_device_buf.ToDevice(gamma_n.mData.data()); + beta_device_buf.ToDevice(beta_n.mData.data()); + + std::string best_op_name; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + + bool pass = true; + int num_kernel = 0; + + // profile device operation instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + {d0_m_n_device_buf.GetDeviceBuffer(), d1_m_n_device_buf.GetDeviceBuffer()}, + gamma_device_buf.GetDeviceBuffer(), + beta_device_buf.GetDeviceBuffer(), + h_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + {StrideD0, StrideD1}, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + ++num_kernel; + + size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get()); + DeviceMem workspace_dev(workspace_sz); + op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer()); + + // re-init E to zero before profiling a kernel + h_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + (sizeof(D0DataType) + sizeof(D1DataType) + sizeof(HDataType)) * M * N + + (sizeof(GammaDataType) + sizeof(BetaDataType)) * N; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + if(time_kernel) + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(ave_time < best_ave_time) + { + best_op_name = op_name; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + h_device_buf.FromDevice(h_m_n.mData.data()); + + pass = pass && ck::utils::check_err( + h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2); + } + } + else + { + if(time_kernel) + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + if(num_kernel == 0) + { + std::cout << "Error: No kernel is applicable" << std::endl; + pass = false; + } + else + { + if(time_kernel) + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_op_name << std::endl; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index bcf25f87e8..d3ab88a167 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -8,6 +8,7 @@ set(PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp profile_gemm_add_multiply.cpp profile_gemm_add_fastgelu.cpp + profile_gemm_add_relu_add_layernorm.cpp profile_gemm_fastgelu.cpp profile_gemm_reduce.cpp profile_batched_gemm.cpp @@ -43,6 +44,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgel target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) @@ -66,5 +68,4 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) - rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/profiler/src/profile_gemm_add_relu_add_layernorm.cpp b/profiler/src/profile_gemm_add_relu_add_layernorm.cpp new file mode 100644 index 0000000000..5cbc3d21f8 --- /dev/null +++ b/profiler/src/profile_gemm_add_relu_add_layernorm.cpp @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_add_relu_add_layernorm_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_add_relu_add_layernorm" +#define OP_DESC "GEMM+Add+Relu+Add+Layernorm" + +int profile_gemm_add_relu_add_layernorm(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN_MN_MN, // 0 + MK_NK_MN_MN_MN, // 1 + KM_KN_MN_MN_MN, // 2 + KM_NK_MN_MN_MN, // 3 + }; + + enum struct MatrixDataType + { + F32, // 0 + F16, // 1 + BF16, // 2 + }; + + if(argc != 16) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16)\n"); + printf("arg3: matrix layout (0: H[m, n] = Layernorm(Relu(A[m, k] * B[k, n] + D0[m, n]) + D1[m, n]);\n"); + printf(" 1: H[m, n] = Layernorm(Relu(A[m, k] * B[n, k] + D0[m, n]) + D1[m, n]);\n"); + printf(" 2: H[m, n] = Layernorm(Relu(A[k, m] * B[k, n] + D0[m, n]) + D1[m, n]);\n"); + printf(" 3: H[m, n] = Layernorm(Relu(A[k, m] * B[n, k] + D0[m, n]) + D1[m, n]))\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideH\n"); + // clang-format on + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideD0 = std::stoi(argv[13]); + const int StrideD1 = std::stoi(argv[14]); + const int StrideH = std::stoi(argv[15]); + + using F16 = ck::half_t; + using F32 = float; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto d0_type, + auto d1_type, + auto e_mean_var_type, + auto gamma_type, + auto beta_type, + auto h_type, + auto a_layout, + auto b_layout, + auto d0_layout, + auto d1_layout, + auto h_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using D0DataType = decltype(d0_type); + using D1DataType = decltype(d1_type); + using EMeanVarDataType = decltype(e_mean_var_type); + using GammaDataType = decltype(gamma_type); + using BetaDataType = decltype(beta_type); + using HDataType = decltype(h_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using D0Layout = decltype(d0_layout); + using D1Layout = decltype(d1_layout); + using HLayout = decltype(h_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideD0 = ck::is_same_v ? N : M; + const int DefaultStrideD1 = ck::is_same_v ? N : M; + const int DefaultStrideH = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_add_relu_add_layernorm_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideD0 < 0) ? DefaultStrideD0 : StrideD0, + (StrideD1 < 0) ? DefaultStrideD1 : StrideD1, + (StrideH < 0) ? DefaultStrideH : StrideH); + + return pass ? 0 : 1; + }; + + if(data_type == MatrixDataType::F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) + { + return profile(F16{}, + F16{}, + F32{}, + F16{}, + F16{}, + F16{}, + F16{}, + F16{}, + F16{}, + Row{}, + Row{}, + Row{}, + Row{}, + Row{}); + } + else if(data_type == MatrixDataType::F16 && layout == MatrixLayout::MK_NK_MN_MN_MN) + { + return profile(F16{}, + F16{}, + F32{}, + F16{}, + F16{}, + F16{}, + F16{}, + F16{}, + F16{}, + Row{}, + Col{}, + Row{}, + Row{}, + Row{}); + } + else if(data_type == MatrixDataType::F16 && layout == MatrixLayout::KM_KN_MN_MN_MN) + { + return profile(F16{}, + F16{}, + F32{}, + F16{}, + F16{}, + F16{}, + F16{}, + F16{}, + F16{}, + Col{}, + Row{}, + Row{}, + Row{}, + Row{}); + } + else if(data_type == MatrixDataType::F16 && layout == MatrixLayout::KM_NK_MN_MN_MN) + { + return profile(F16{}, + F16{}, + F32{}, + F16{}, + F16{}, + F16{}, + F16{}, + F16{}, + F16{}, + Col{}, + Col{}, + Row{}, + Row{}, + Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_add_relu_add_layernorm); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b2e25e4ca7..6f43e52355 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -27,7 +27,7 @@ function(add_gtest_executable TEST_NAME) # suppress gtest warnings target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) target_link_libraries(${TEST_NAME} PRIVATE gtest_main) - add_test(NAME ${TEST_NAME} COMMAND $ ) + add_test(NAME ${TEST_NAME} COMMAND $) rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) endfunction(add_gtest_executable TEST_NAME) @@ -36,6 +36,7 @@ add_subdirectory(space_filling_curve) add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) +add_subdirectory(gemm_layernorm) add_subdirectory(gemm_split_k) add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt new file mode 100644 index 0000000000..c4feb5c564 --- /dev/null +++ b/test/gemm_layernorm/CMakeLists.txt @@ -0,0 +1,7 @@ +add_custom_target(test_gemm_layernorm) + +add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) + +target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) + +add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) diff --git a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp new file mode 100644 index 0000000000..740c63aa7e --- /dev/null +++ b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "profiler/profile_gemm_add_relu_add_layernorm_impl.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using F16 = ck::half_t; +using F32 = float; +using ck::index_t; + +template +class TestGemmAddReluAddLayernorm : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using D0DataType = std::tuple_element_t<3, Tuple>; + using D1DataType = std::tuple_element_t<4, Tuple>; + using EMeanVarDataType = std::tuple_element_t<5, Tuple>; + using GammaDataType = std::tuple_element_t<6, Tuple>; + using BetaDataType = std::tuple_element_t<7, Tuple>; + using HDataType = std::tuple_element_t<8, Tuple>; + using ALayout = std::tuple_element_t<9, Tuple>; + using BLayout = std::tuple_element_t<10, Tuple>; + using D0Layout = std::tuple_element_t<11, Tuple>; + using D1Layout = std::tuple_element_t<12, Tuple>; + using HLayout = std::tuple_element_t<13, Tuple>; + + void Run() + { + std::vector> lengths = { + {1024, 1024, 1024}, {2048, 640, 640}, {1, 1, 1}}; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideD0 = 0; + int StrideD1 = ck::is_same_v ? N : M; + int StrideH = ck::is_same_v ? N : M; + + bool success = ck::profiler::profile_gemm_add_relu_add_layernorm_impl( + true, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideH); + + EXPECT_TRUE(success); + } + } +}; + +using KernelTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmAddReluAddLayernorm, KernelTypes); +TYPED_TEST(TestGemmAddReluAddLayernorm, Test_FP16) { this->Run(); } diff --git a/test/normalization/CMakeLists.txt b/test/normalization/CMakeLists.txt index 456423f25d..a5d7fb2982 100644 --- a/test/normalization/CMakeLists.txt +++ b/test/normalization/CMakeLists.txt @@ -1,17 +1,16 @@ -add_custom_target(test_layernorm) +add_custom_target(test_normalization) add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp) add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp) add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp) -add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp) - +add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp) target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance) target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance) target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance) target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance) -add_dependencies(test_layernorm test_layernorm2d_fp32) -add_dependencies(test_layernorm test_layernorm2d_fp16) -add_dependencies(test_layernorm test_groupnorm_fp16) -add_dependencies(test_layernorm test_groupnorm_fp32) +add_dependencies(test_normalization test_layernorm2d_fp32) +add_dependencies(test_normalization test_layernorm2d_fp16) +add_dependencies(test_normalization test_groupnorm_fp16) +add_dependencies(test_normalization test_groupnorm_fp32) From 0ac0f51ad64b1c82bc3417e08123f7d0fac5c6c5 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 10 Feb 2023 11:00:37 -0800 Subject: [PATCH 10/22] enable batched_gemm_softmax_bf16 tests (#582) --- .../test_batched_gemm_softmax_gemm_permute_bf16.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp index 4a775e6d92..defe361240 100644 --- a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp +++ b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp @@ -27,7 +27,7 @@ using KernelTypes = ::testing::Types< TYPED_TEST_SUITE(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, KernelTypes); -TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Test_BF16) { this->Run(); } +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16) { this->Run(); } TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_PadM) { @@ -96,7 +96,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Test_BF16_OddO) this->Run(); } -TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16_IrregularK) +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Bench_BF16_IrregularK) { this->lengths_ = std::vector>{{256, 256, 160, 160, 1, 16}, {256, 64, 160, 64, 1, 16}, @@ -109,7 +109,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF1 this->Run(); } -TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, DISABLED_Bench_BF16) +TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteBF16, Bench_BF16) { this->lengths_ = std::vector>{ {256, 256, 64, 64, 48, 16}, From 8f42780fd6bc943e1f9168807f0bb84a7bfa8c94 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Mon, 13 Feb 2023 17:06:24 +0100 Subject: [PATCH 11/22] GroupedGEMM more bigger tiles. (#577) * Adding more bigger tiles. * Remove failing instance. * Remove instances which that don't improve perf. --------- Co-authored-by: Adam Osewski Co-authored-by: zjing14 --- ...ouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp | 13 ++++++------- ...ouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp | 12 +++++++----- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp index a88bf6042b..a93cb7fc84 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp @@ -64,16 +64,15 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_irregular_tile_instances = st //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp index 9c578319cb..2ace1b2432 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -61,15 +61,17 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_irregular_tile_instances = st //###################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //###################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 32, 8, 8, 16, 16, 1, 1, S<4, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 32, 8, 8, 32, 32, 1, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, DeviceGroupedGemm_Xdl< Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on From 06f1fc864c47f68fadb9c92f8f762a7bf83c15f6 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 14 Feb 2023 16:06:24 -0800 Subject: [PATCH 12/22] Remove the workaround for bf16 attention tests. (#586) * remove workanround in bf16 attention test * clean up another workaround --- include/ck/ck.hpp | 7 ------- ...tched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp | 8 -------- ...gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 8 -------- 3 files changed, 23 deletions(-) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 312d53d366..ffd7e74f12 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -168,13 +168,6 @@ // tuning parameter #define CK_WORKAROUND_SWDEV_325164 0 -// workaround: a BF16 attention kernel for gfx908 is likely affected by a compiler issue -#ifdef __gfx908__ -#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 1 -#else // __gfx90a__, ... -#define CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE 0 -#endif // __gfx908__ - // flag to enable (1) or disable (0) the debugging output in some kernels #define DEBUG_LOG 0 diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp index da93f63538..6a6f19d71e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -1077,14 +1077,6 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle } } // end gemm1 - // workaround compiler issue; see ck/ck.hpp - if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 && - is_same_v && MPerBlock == 256 && NPerBlock == 128 && - Gemm1NPerBlock == 128) - { - __builtin_amdgcn_sched_barrier(0); - } - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp index fec360b7fa..ce39c4967b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -879,14 +879,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle } } // end gemm1 - // workaround compiler issue; see ck/ck.hpp - if constexpr(CK_WORKAROUND_SWDEV_XXXXXX_BF16_ATTEN_FWD_GFX908_ISSUE == 1 && - is_same_v && MPerBlock == 256 && NPerBlock == 128 && - Gemm1NPerBlock == 128) - { - __builtin_amdgcn_sched_barrier(0); - } - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); From e9fd122889f76ed050c80315cf73cda60f0ad248 Mon Sep 17 00:00:00 2001 From: Adam Osewski <19374865+aosewski@users.noreply.github.com> Date: Wed, 15 Feb 2023 18:16:47 +0100 Subject: [PATCH 13/22] Conv3D FWD BWD WRW fp16 fp32 client examples (#559) * Conv3d bwd weight client example. * Update year in license * Convolution bwd data 3D fp16/fp32 client example. * Client example for convnd fwd fp16 fp32 * clang-format * Review remarks. * Fix compiler err. * Update data layout to standard one. * Add conv 3d fwd NDHWGC instances * clang-format * Conv3d fwd NDHWGC instances. --------- Co-authored-by: Adam Osewski Co-authored-by: zjing14 --- .../11_grouped_conv_bwd_weight/CMakeLists.txt | 9 +- ...ouped_conv2d_bwd_weight.cpp => common.hpp} | 128 +++++--- .../grouped_conv2d_bwd_weight_fp16.cpp | 41 +++ .../grouped_conv3d_bwd_weight_fp16.cpp | 53 +++ .../grouped_conv3d_bwd_weight_fp32.cpp | 53 +++ .../15_convnd_bwd_data/CMakeLists.txt | 5 + client_example/15_convnd_bwd_data/common.hpp | 233 ++++++++++++++ .../conv3d_bwd_data_fp16.cpp | 42 +++ .../conv3d_bwd_data_fp32.cpp | 42 +++ client_example/16_convnd_fwd/CMakeLists.txt | 5 + client_example/16_convnd_fwd/common.hpp | 304 ++++++++++++++++++ .../16_convnd_fwd/conv3d_fwd_fp16.cpp | 44 +++ .../16_convnd_fwd/conv3d_fwd_fp32.cpp | 44 +++ .../run_gemm_add_multiply_example.inc | 5 +- .../gpu/grouped_convolution_forward.hpp | 82 +++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 5 + ...xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp | 129 ++++++++ ..._xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp | 129 ++++++++ ..._xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp | 128 ++++++++ ...xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp | 125 +++++++ 20 files changed, 1565 insertions(+), 41 deletions(-) rename client_example/11_grouped_conv_bwd_weight/{grouped_conv2d_bwd_weight.cpp => common.hpp} (59%) create mode 100644 client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp create mode 100644 client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp create mode 100644 client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp create mode 100644 client_example/15_convnd_bwd_data/CMakeLists.txt create mode 100644 client_example/15_convnd_bwd_data/common.hpp create mode 100644 client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp create mode 100644 client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp create mode 100644 client_example/16_convnd_fwd/CMakeLists.txt create mode 100644 client_example/16_convnd_fwd/common.hpp create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp diff --git a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt index 3e3f667766..761e0de95a 100644 --- a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt +++ b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,2 +1,7 @@ -add_executable(client_grouped_conv2d_bwd_weight grouped_conv2d_bwd_weight.cpp) -target_link_libraries(client_grouped_conv2d_bwd_weight PRIVATE composable_kernel::device_operations) +add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp) +add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp) +add_executable(client_grouped_conv3d_bwd_weight_fp32 grouped_conv3d_bwd_weight_fp32.cpp) + +target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_operations) diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight.cpp b/client_example/11_grouped_conv_bwd_weight/common.hpp similarity index 59% rename from client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight.cpp rename to client_example/11_grouped_conv_bwd_weight/common.hpp index 1ecc856895..a906263333 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight.cpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -13,27 +13,8 @@ #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -using InDataType = ck::half_t; -using WeiDataType = ck::half_t; -using OutDataType = ck::half_t; - -using InLayout = ck::tensor_layout::convolution::GNHWC; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 32; -static constexpr ck::index_t N = 256; -static constexpr ck::index_t K = 192; -static constexpr ck::index_t C = 192; -static constexpr ck::index_t Y = 3; -static constexpr ck::index_t X = 3; -static constexpr ck::index_t Hi = 28; -static constexpr ck::index_t Wi = 28; -static constexpr ck::index_t Ho = 28; -static constexpr ck::index_t Wo = 28; - struct SimpleDeviceMem { SimpleDeviceMem() = delete; @@ -50,22 +31,93 @@ struct SimpleDeviceMem void* p_mem_; }; -int main() +template +std::size_t GetFlops(ck::index_t G, + ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::array& output_spatial_lengths, + const std::array& filter_spatial_lengths) { - std::array input_spatial_lengths{Hi, Wi}; - std::array filter_spatial_lengths{Y, X}; - std::array output_spatial_lengths{Ho, Wo}; + // 2 * G * N * K * C * * + return static_cast(2) * G * N * K * C * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies<>()); +} - std::array conv_filter_strides{1, 1}; - std::array conv_filter_dilations{1, 1}; - std::array input_left_pads{1, 1}; - std::array input_right_pads{1, 1}; +template +std::size_t GetInputByte(ck::index_t G, + ck::index_t N, + ck::index_t C, + const std::array& input_spatial_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * (G * N * C * + std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + static_cast(1), + std::multiplies<>())); +} + +template +std::size_t GetWeightByte(ck::index_t G, + ck::index_t K, + ck::index_t C, + const std::array& filter_spatial_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * (G * K * C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies<>())); +} + +template +std::size_t GetOutputByte(ck::index_t G, + ck::index_t N, + ck::index_t K, + const std::array& output_spatial_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * (G * N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies())); +} + +template +bool run_grouped_conv_bwd_weight( + ck::index_t G, + ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) +{ ck::index_t split_k = 2; - - SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); + SimpleDeviceMem in(GetInputByte(G, N, C, input_spatial_lengths)); + SimpleDeviceMem wei(GetWeightByte(G, K, C, filter_spatial_lengths)); + SimpleDeviceMem out(GetOutputByte(G, N, K, output_spatial_lengths)); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeightRun(argument_ptr.get(), StreamConfig{nullptr, true}); - std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + - sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * G * N * Ho * Wo * K; + std::size_t flop = + GetFlops(G, N, K, C, output_spatial_lengths, filter_spatial_lengths); + std::size_t num_bytes = + GetInputByte(G, N, C, input_spatial_lengths) + + GetWeightByte(G, K, C, filter_spatial_lengths) + + GetOutputByte(G, N, K, output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; @@ -149,7 +203,7 @@ int main() if(best_op_id < 0) { std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; + return false; } std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops @@ -187,4 +241,6 @@ int main() std::cout << "Done" << std::endl; } + + return true; } diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp new file mode 100644 index 0000000000..1903bd95b6 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 256; +static constexpr ck::index_t K = 192; +static constexpr ck::index_t C = 192; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +int main() +{ + return run_grouped_conv_bwd_weight( + G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp new file mode 100644 index 0000000000..2f2b5d4e21 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 8; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 128; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_bwd_weight(G, + N, + K, + C, + {Di, Hi, Wi}, + {Z, Y, X}, + {Do, Ho, Wo}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp new file mode 100644 index 0000000000..796311d231 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 8; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 128; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_bwd_weight(G, + N, + K, + C, + {Di, Hi, Wi}, + {Z, Y, X}, + {Do, Ho, Wo}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/15_convnd_bwd_data/CMakeLists.txt b/client_example/15_convnd_bwd_data/CMakeLists.txt new file mode 100644 index 0000000000..8a60a71674 --- /dev/null +++ b/client_example/15_convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) +add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) + +target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_operations) diff --git a/client_example/15_convnd_bwd_data/common.hpp b/client_example/15_convnd_bwd_data/common.hpp new file mode 100644 index 0000000000..9799fb73a5 --- /dev/null +++ b/client_example/15_convnd_bwd_data/common.hpp @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +std::size_t GetFlops(ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::vector& output_spatial_lengths, + const std::vector& weights_spatial_lengths) +{ + // 2 * N * K * C * * + + return static_cast(2) * N * K * C * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::begin(weights_spatial_lengths), + std::end(weights_spatial_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetInputByte(ck::index_t N, ck::index_t C, const std::vector& input_spatial_lengths) +{ + // sizeof(InDataType) * (N * C * ) + + return sizeof(InDataType) * N * C * + std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(ck::index_t K, ck::index_t C, const std::vector& weights_spatial_lengths) +{ + // sizeof(WeiDataType) * (K * C * ) + + return sizeof(WeiDataType) * K * C * + std::accumulate(std::begin(weights_spatial_lengths), + std::end(weights_spatial_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(ck::index_t N, ck::index_t K, const std::vector& output_spatial_lengths) +{ + // sizeof(OutDataType) * (N * K * ); + return sizeof(OutDataType) * N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_conv_bwd_data(ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::vector& in_spatial_lengths, + const std::vector& wei_spatial_lengths, + const std::vector& out_spatial_lengths) +{ + std::size_t in_mem_size = GetInputByte(N, C, in_spatial_lengths); + std::size_t wei_mem_size = GetWeightByte(K, C, wei_spatial_lengths); + std::size_t out_mem_size = GetOutputByte(N, K, out_spatial_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + std::vector filter_strides(NumDimSpatial, 1); + std::vector filter_dilations(NumDimSpatial, 1); + std::vector input_left_pads(NumDimSpatial, 1); + std::vector input_right_pads(NumDimSpatial, 1); + + using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + std::size_t flop = GetFlops(N, K, C, out_spatial_lengths, wei_spatial_lengths); + std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + out.GetDeviceBuffer(), + N, + K, + C, + in_spatial_lengths, + wei_spatial_lengths, + out_spatial_lengths, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + out.GetDeviceBuffer(), + N, + K, + C, + in_spatial_lengths, + wei_spatial_lengths, + out_spatial_lengths, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp new file mode 100644 index 0000000000..5210567241 --- /dev/null +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWC; +using WeiLayout = ck::tensor_layout::convolution::KZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +int main() +{ + return run_conv_bwd_data(N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp new file mode 100644 index 0000000000..441bdfe7be --- /dev/null +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; + +using InLayout = ck::tensor_layout::convolution::NDHWC; +using WeiLayout = ck::tensor_layout::convolution::KZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +int main() +{ + return run_conv_bwd_data(N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt new file mode 100644 index 0000000000..e2580a370c --- /dev/null +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) +add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) + +target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp new file mode 100644 index 0000000000..a6bb5aa65b --- /dev/null +++ b/client_example/16_convnd_fwd/common.hpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths) +{ + // 2 * G * N * K * C * * + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return static_cast(2) * G * N * K * C * + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t +GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_grouped_conv_fwd(std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), 2), rend(wei_lengths)); + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), 2), rend(wei_strides)); + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + std::size_t flop = GetFlops(out_lengths, wei_lengths); + std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp new file mode 100644 index 0000000000..10f914bbee --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::KZYXGC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp new file mode 100644 index 0000000000..43c98f1e9b --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::KZYXGC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc index 4f7a8a4ca7..e1b2bccfe1 100644 --- a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc +++ b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc @@ -53,7 +53,6 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data()); @@ -84,8 +83,8 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi if(!device_op.IsSupportedArgument(argument)) { - std::cout << "wrong! this device_op instance does not support this problem" << std::endl; - return true; + std::cout << "wrong! this device_op instance does not support this problem" << std::endl; + return true; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index ee38b73827..a8df7f0d5b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -244,6 +244,63 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( PassThrough, PassThrough>>>& instances); +// grouped conv3d forward, NDHWGC/KZYXGC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( + std::vector>>& instances); + template && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances(op_ptrs); + } + } return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 78eedca5f7..90efc09ee7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -3,4 +3,9 @@ add_instance_library(device_grouped_conv3d_fwd_instance device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp + + device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..8c38493735 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..487cd22721 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000..d497cd57ed --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp new file mode 100644 index 0000000000..2e53fbbda5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances = std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 0cfda84d0576b7dc271755b3da214a4ab55fa33d Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Thu, 16 Feb 2023 01:50:51 +0800 Subject: [PATCH 14/22] [Navi3x] Add Device Operations (#567) * wmma_op + unit test * add arch limitation to wmma test * change arch limitation * Refactor + Add all type unit test(int4 compile failed) * Add f32_16x16x16_bf16 unit test * tempsave * tempsave * tempsave * runtime bug, cannot find symbol * workaround for incorrect HIP warpSize return value * debugging * tempsave * Correctness OK, waiting for optimization * Tidy up + format * temp save * temp save, reproduce the v_bfi_b32 issue * add inline asm for wmmaop test * tidy up * clean some debug purpose code * discard some codes * clang format * clang format * compiler issue fixed + increase tile size * navi3x_multipleD+example * temp save * workable * batchedgemm[OK], groupconv[debug] * groupconv: Sanity check[OK], Performance[Bad] * navi3x_groupconv_need_optimization * format * Add arch limitation to all wmma examples * fix bug: example30 input conv args --- example/01_gemm/CMakeLists.txt | 8 +- example/02_gemm_bilinear/CMakeLists.txt | 3 + .../gemm_bilinear_wmma_fp16.cpp | 304 ++++++ .../CMakeLists.txt | 4 + .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 431 ++++++++ .../CMakeLists.txt | 3 + .../30_grouped_conv_fwd_multiple_d/common.hpp | 2 +- .../common_wmma.hpp | 355 +++++++ ...ouped_conv_fwd_bias_relu_add_wmma_fp16.cpp | 26 + ...ed_conv_fwd_bias_relu_add_wmma_example.inc | 286 +++++ ...d_contraction_multiple_d_wmma_cshuffle.hpp | 991 ++++++++++++++++++ .../device_gemm_multiple_d_wmma_cshuffle.hpp | 654 ++++++++++++ ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 850 +++++++++++++++ ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 937 +++++++++++++++++ 14 files changed, 4850 insertions(+), 4 deletions(-) create mode 100644 example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp create mode 100644 example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp create mode 100644 example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp create mode 100644 example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp create mode 100644 example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index ecff4298eb..7f8fdf35f4 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -38,7 +38,9 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) -add_custom_target(example_gemm_wmma) -add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) -add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) +if(GPU_TARGETS MATCHES "gfx1100") + add_custom_target(example_gemm_wmma) + add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) + add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) +endif() diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 10ec0f1a71..1343a814ad 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1 +1,4 @@ add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) +if(GPU_TARGETS MATCHES "gfx1100") + add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) +endif() diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp new file mode 100644 index 0000000000..ff99bf4641 --- /dev/null +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 256, + 128, + 256, + 8, + 8, + 16, + 16, + 4, + 4, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index 40470f27d4..c74294feb0 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1 +1,5 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) + +if(GPU_TARGETS MATCHES "gfx1100") + add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) +endif() diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp new file mode 100644 index 0000000000..30ad38a566 --- /dev/null +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -0,0 +1,431 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Add = ck::tensor_operation::element_wise::Add; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F16; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F16; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 1; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Add; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceOpInstanceKKNN = + ck::tensor_operation::device::DeviceBatchedContractionMultipleD_Wmma_CShuffle, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = + false> +struct ReferenceContraction_G2_M2_N2_K1 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_gs_ms_ks_{a_gs_ms_ks}, + b_gs_ns_ks_{b_gs_ns_ks}, + e_gs_ms_ns_{e_gs_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_gs_ms_ks_; + const Tensor& b_gs_ns_ks_; + Tensor& e_gs_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_G2_M2_N2_K1::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto g0, auto g1, auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_gs_ms_ks_.mDesc.GetLengths()[4]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, + ck::type_convert(arg.a_gs_ms_ks_(g0, g1, m0, m1, k0))); + arg.b_element_op_( + v_b, + ck::type_convert(arg.b_gs_ns_ks_(g0, g1, n0, n1, k0))); + + v_acc += v_a * v_b; + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_gs_ms_ns_(g0, g1, m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_gs_ms_ns_.mDesc.GetLengths()[0], + arg.e_gs_ms_ns_.mDesc.GetLengths()[1], + arg.e_gs_ms_ns_.mDesc.GetLengths()[2], + arg.e_gs_ms_ns_.mDesc.GetLengths()[3], + arg.e_gs_ms_ns_.mDesc.GetLengths()[4], + arg.e_gs_ms_ns_.mDesc.GetLengths()[5])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_gs_ms_ks, + const Tensor& b_gs_ns_ks, + Tensor& e_gs_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + a_gs_ms_ks, b_gs_ns_ks, e_gs_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_G2_M2_N2_K1" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + ck::index_t G0 = 1; + ck::index_t G1 = 2; + + ck::index_t M0 = 4; + ck::index_t M1 = 128; + + ck::index_t N0 = 16; + ck::index_t N1 = 256; + + ck::index_t K0 = 2048; + + // A[G0, G1, M0, M1, K0] + std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; + std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; + // B[G0, G1, N0, N1, K0] + std::vector b_gs_ns_ks_lengths{G0, G1, N0, N1, K0}; + std::vector b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1}; + + // D[G0, G1, M0, N0, M1, N1] + std::vector d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1}; + // E[G0, G1, M0, N0, M1, N1] + std::vector e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1}; + std::vector e_gs_ms_ns_strides{ + G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + exit(0); + } + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); + Tensor e_gs_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + Tensor e_gs_ms_ns_device_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b_gs_ns_ks: " << b_gs_ns_ks.mDesc << std::endl; + std::cout << "d_gs_ms_ns: " << d_gs_ms_ns.mDesc << std::endl; + std::cout << "e_gs_ms_ns: " << e_gs_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_gs_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_gs_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * + e_gs_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b_device_buf.ToDevice(b_gs_ns_ks.mData.data()); + d_device_buf.ToDevice(d_gs_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + std::array, 1>{d_gs_ms_ns_lengths}, + std::array, 1>{d_gs_ms_ns_strides}, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t G = + ck::accumulate_n(e_gs_ms_ns_lengths.begin(), NumDimG, 1, std::multiplies<>{}); + + ck::index_t M = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{}); + std::cout << "GMNK=" << G << ", " << M << ", " << N << ", " << K << std::endl; + std::size_t flop = std::size_t(2) * G * M * N * K; + std::size_t num_btype = sizeof(ADataType) * G * M * K + sizeof(BDataType) * G * K * N + + sizeof(DDataType) * G * M * N + sizeof(EDataType) * G * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_gs_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + using ReferenceOpInstance = ReferenceContraction_G2_M2_N2_K1; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_gs_ms_ks, b_gs_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t g0 = 0; g0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[0]; ++g0) + { + for(size_t g1 = 0; g1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[1]; ++g1) + { + for(size_t m0 = 0; m0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[2]; ++m0) + { + for(size_t m1 = 0; m1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[3]; ++m1) + { + for(size_t n0 = 0; n0 < e_gs_ms_ns_host_result.mDesc.GetLengths()[4]; ++n0) + { + for(size_t n1 = 0; n1 < e_gs_ms_ns_host_result.mDesc.GetLengths()[5]; + ++n1) + { + cde_element_op(e_gs_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + c_ms_ns_host_result(g0, g1, m0, m1, n0, n1), + d_gs_ms_ns(g0, g1, m0, m1, n0, n1)); + } + } + } + } + } + } + + return ck::utils::check_err(e_gs_ms_ns_device_result, e_gs_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 61b2b2f6f3..acf9bcdb46 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -16,6 +16,9 @@ if(USE_BITINT_EXTENSION_INT4) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) endif() # USE_BITINT_EXTENSION_INT4 +if(GPU_TARGETS MATCHES "gfx1100") + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) +endif() add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) diff --git a/example/30_grouped_conv_fwd_multiple_d/common.hpp b/example/30_grouped_conv_fwd_multiple_d/common.hpp index d6d6dd6ff1..e7c6ed9b93 100644 --- a/example/30_grouped_conv_fwd_multiple_d/common.hpp +++ b/example/30_grouped_conv_fwd_multiple_d/common.hpp @@ -137,7 +137,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_param = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp new file mode 100644 index 0000000000..eb6975a6d8 --- /dev/null +++ b/example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp @@ -0,0 +1,355 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" + +using BF16 = ck::bhalf_t; +using FP16 = ck::half_t; +using FP32 = float; +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +using I4 = ck::int4_t; +#endif +using I8 = std::int8_t; +using I32 = std::int32_t; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +struct CommonLayoutSetting +{ + using InputLayout = InputLay; + using WeightLayout = WeightLay; + using OutputLayout = OutputLay; +}; + +template +struct CommonLayoutSettingSelector; + +namespace ctl = ck::tensor_layout::convolution; + +template <> +struct CommonLayoutSettingSelector<1> final + : CommonLayoutSetting +{ +}; + +template <> +struct CommonLayoutSettingSelector<2> final + : CommonLayoutSetting +{ +}; + +template <> +struct CommonLayoutSettingSelector<3> final + : CommonLayoutSetting +{ +}; + +template +using InputLayout = typename CommonLayoutSettingSelector::InputLayout; + +template +using WeightLayout = typename CommonLayoutSettingSelector::WeightLayout; + +template +using OutputLayout = typename CommonLayoutSettingSelector::OutputLayout; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +}; + +#define DefaultConvParam \ + ck::utils::conv::ConvParam \ + { \ + 2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \ + } + +inline void print_help_msg() +{ + std::cerr << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=no, 1=yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; +} + +inline bool parse_cmd_args(int argc, + char* argv[], + ExecutionConfig& config, + ck::utils::conv::ConvParam& conv_param) +{ + constexpr int num_execution_config_args = + 3; // arguments for do_verification, init_method, time_kernel + constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_ + + constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args; + constexpr int threshold_to_catch_all_args = + threshold_to_catch_partial_args + num_conv_param_leading_args; + + if(argc == 1) + { + // use default + } + // catch only ExecutionConfig arguments + else if(argc == threshold_to_catch_partial_args) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + // catch both ExecutionConfig & ConvParam arguments + else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0)) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + const ck::index_t num_dim_spatial = std::stoi(argv[4]); + conv_param = ck::utils::conv::parse_conv_param( + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); + } + else + { + print_help_msg(); + return false; + } + + return true; +} + +inline HostTensorDescriptor make_input_descriptor(const ck::utils::conv::ConvParam& conv_param) +{ + switch(conv_param.num_dim_spatial_) + { + case 1: + return HostTensorDescriptor( + {conv_param.G_, conv_param.N_, conv_param.C_, conv_param.input_spatial_lengths_[0]}, + { + conv_param.C_, // g + conv_param.input_spatial_lengths_[0] * conv_param.G_ * conv_param.C_, // n + 1, // c + conv_param.G_ * conv_param.C_ // wi + }); + + case 2: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.N_, + conv_param.C_, + conv_param.input_spatial_lengths_[0], + conv_param.input_spatial_lengths_[1]}, + { + conv_param.C_, // g + conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] * + conv_param.G_ * conv_param.C_, // n + 1, // c + conv_param.input_spatial_lengths_[1] * conv_param.G_ * conv_param.C_, // hi + conv_param.G_ * conv_param.C_ // wi + }); + + case 3: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.N_, + conv_param.C_, + conv_param.input_spatial_lengths_[0], + conv_param.input_spatial_lengths_[1], + conv_param.input_spatial_lengths_[2]}, + { + conv_param.C_, // g + conv_param.input_spatial_lengths_[0] * conv_param.input_spatial_lengths_[1] * + conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // n + 1, // c + conv_param.input_spatial_lengths_[1] * conv_param.input_spatial_lengths_[2] * + conv_param.G_ * conv_param.C_, // di + conv_param.input_spatial_lengths_[2] * conv_param.G_ * conv_param.C_, // hi + conv_param.G_ * conv_param.C_ // wi + }); + } + + throw std::runtime_error("unsuppored # dim spatial"); +} + +inline HostTensorDescriptor make_weight_descriptor(const ck::utils::conv::ConvParam& conv_param) +{ + switch(conv_param.num_dim_spatial_) + { + case 1: + return HostTensorDescriptor( + {conv_param.G_, conv_param.K_, conv_param.C_, conv_param.filter_spatial_lengths_[0]}, + { + conv_param.K_ * conv_param.filter_spatial_lengths_[0] * conv_param.C_, // g + conv_param.filter_spatial_lengths_[0] * conv_param.C_, // k + 1, // c + conv_param.C_ // x + }); + case 2: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.K_, + conv_param.C_, + conv_param.filter_spatial_lengths_[0], + conv_param.filter_spatial_lengths_[1]}, + { + conv_param.K_ * conv_param.filter_spatial_lengths_[0] * + conv_param.filter_spatial_lengths_[1] * conv_param.C_, // g + conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] * + conv_param.C_, // k + 1, // c + conv_param.filter_spatial_lengths_[1] * conv_param.C_, // y + conv_param.C_ // x + }); + case 3: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.K_, + conv_param.C_, + conv_param.filter_spatial_lengths_[0], + conv_param.filter_spatial_lengths_[1], + conv_param.filter_spatial_lengths_[2]}, + { + conv_param.K_ * conv_param.filter_spatial_lengths_[0] * + conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] * + conv_param.C_, // g + conv_param.filter_spatial_lengths_[0] * conv_param.filter_spatial_lengths_[1] * + conv_param.filter_spatial_lengths_[2] * conv_param.C_, // k + 1, // c + conv_param.filter_spatial_lengths_[1] * conv_param.filter_spatial_lengths_[2] * + conv_param.C_, // z + conv_param.filter_spatial_lengths_[2] * conv_param.C_, // y + conv_param.C_ // x + }); + } + + throw std::runtime_error("unsuppored # dim spatial"); +} + +inline HostTensorDescriptor make_bias_descriptor(const ck::utils::conv::ConvParam& conv_param) +{ + switch(conv_param.num_dim_spatial_) + { + case 1: + return HostTensorDescriptor( + {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]}, + { + conv_param.K_, // g + 0, // k + 1, // c + 0 // x + }); + case 2: + return HostTensorDescriptor({conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1]}, + { + conv_param.K_, // g + 0, // n + 1, // k + 0, // ho + 0 // wo + }); + case 3: + return HostTensorDescriptor({conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1], + conv_param.output_spatial_lengths_[2]}, + { + conv_param.K_, // g + 0, // n + 1, // k + 0, // z + 0, // y + 0 // x + }); + } + + throw std::runtime_error("unsuppored # dim spatial"); +} + +inline HostTensorDescriptor make_output_descriptor(const ck::utils::conv::ConvParam& conv_param) +{ + + switch(conv_param.num_dim_spatial_) + { + case 1: + return HostTensorDescriptor( + {conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]}, + { + conv_param.K_, // g + conv_param.output_spatial_lengths_[0] * conv_param.G_ * conv_param.K_, // n + 1, // k + conv_param.G_ * conv_param.K_ // wo + }); + case 2: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1]}, + { + conv_param.K_, // g + conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] * + conv_param.G_ * conv_param.K_, // n + 1, // k + conv_param.output_spatial_lengths_[1] * conv_param.G_ * conv_param.K_, // ho + conv_param.G_ * conv_param.K_ // wo + }); + + case 3: + return HostTensorDescriptor( + {conv_param.G_, + conv_param.N_, + conv_param.K_, + conv_param.output_spatial_lengths_[0], + conv_param.output_spatial_lengths_[1], + conv_param.output_spatial_lengths_[2]}, + { + conv_param.K_, // g + conv_param.output_spatial_lengths_[0] * conv_param.output_spatial_lengths_[1] * + conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // n + 1, // k + conv_param.output_spatial_lengths_[1] * conv_param.output_spatial_lengths_[2] * + conv_param.G_ * conv_param.K_, // do + conv_param.output_spatial_lengths_[2] * conv_param.G_ * conv_param.K_, // ho + conv_param.G_ * conv_param.K_ // wo + }); + } + + throw std::runtime_error("unsuppored # dim spatial"); +} diff --git a/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp new file mode 100644 index 0000000000..9d1d257a28 --- /dev/null +++ b/example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common_wmma.hpp" + +// kernel data types +using InKernelDataType = FP16; +using WeiKernelDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using BiasKernelDataType = FP16; +using ResidualKernelDataType = FP16; +using OutKernelDataType = FP16; + +// tensor data types +using InUserDataType = InKernelDataType; +using WeiUserDataType = WeiKernelDataType; +using OutUserDataType = OutKernelDataType; + +using InElementOp = PassThrough; +using WeiElementOp = PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd; + +#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); } diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc new file mode 100644 index 0000000000..8161b1088a --- /dev/null +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -0,0 +1,286 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +template +struct LayoutSetting +{ + using BiasLayout = BiasLay; + using ResidualLayout = ResidualLay; +}; + +template +struct LayoutSettingSelector; + +template <> +struct LayoutSettingSelector<1> final : LayoutSetting +{ +}; + +template <> +struct LayoutSettingSelector<2> final : LayoutSetting +{ +}; + +template <> +struct LayoutSettingSelector<3> final : LayoutSetting +{ +}; + +template +using BiasLayout = typename LayoutSettingSelector::BiasLayout; + +template +using ResidualLayout = typename LayoutSettingSelector::ResidualLayout; + +template +using DeviceConvFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< + NDimSpatial, + InputLayout, + WeightLayout, + ck::Tuple, ResidualLayout>, + OutputLayout, + InKernelDataType, + WeiKernelDataType, + ck::Tuple, + OutKernelDataType, + AccDataType, + CShuffleDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 16, // MPerWMMA + 16, // NPerWMMA + 4, // MRepeat + 2, // NRepeat + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + true, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + true, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +template +using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +template +bool run_grouped_conv_fwd_bias_relu_add(const ExecutionConfig& config, + const ck::utils::conv::ConvParam& conv_param) +{ + static_assert(1 <= NDimSpatial && NDimSpatial <= 3, "Unsupported NDimSpatial"); + + const auto in_g_n_c_wis_desc = make_input_descriptor(conv_param); + const auto wei_g_k_c_xs_desc = make_weight_descriptor(conv_param); + const auto bias_g_n_k_wos_desc = make_bias_descriptor(conv_param); + const auto out_g_n_k_wos_desc = make_output_descriptor(conv_param); + + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor bias(bias_g_n_k_wos_desc); + Tensor residual(bias_g_n_k_wos_desc); + Tensor out_host(out_g_n_k_wos_desc); + Tensor out_device(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "residual: " << residual.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(config.init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InKernelDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiKernelDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem bias_device_buf(sizeof(OutKernelDataType) * bias.mDesc.GetElementSpaceSize()); + DeviceMem residual_device_buf(sizeof(OutKernelDataType) * residual.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutKernelDataType) * out_device.mDesc.GetElementSpaceSize()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor in_converted(in); + const Tensor wei_converted(wei); + const Tensor bias_converted(bias); + const Tensor residual_converted(residual); + + in_device_buf.ToDevice(in_converted.mData.data()); + wei_device_buf.ToDevice(wei_converted.mData.data()); + bias_device_buf.ToDevice(bias_converted.mData.data()); + residual_device_buf.ToDevice(residual_converted.mData.data()); +#else + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + residual_device_buf.ToDevice(residual.mData.data()); +#endif + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array d0_g_n_k_wos_lengths{}; + std::array d0_g_n_k_wos_strides{}; + std::array d1_g_n_k_wos_lengths{}; + std::array d1_g_n_k_wos_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(bias_g_n_k_wos_desc.GetLengths(), d0_g_n_k_wos_lengths); + copy(bias_g_n_k_wos_desc.GetStrides(), d0_g_n_k_wos_strides); + copy(bias_g_n_k_wos_desc.GetLengths(), d1_g_n_k_wos_lengths); + copy(bias_g_n_k_wos_desc.GetStrides(), d1_g_n_k_wos_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // do Conv + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = + conv.MakeArgument(in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + std::array{bias_device_buf.GetDeviceBuffer(), + residual_device_buf.GetDeviceBuffer()}, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, 2>{ + {d0_g_n_k_wos_lengths, d1_g_n_k_wos_lengths}}, + std::array, 2>{ + {d0_g_n_k_wos_strides, d1_g_n_k_wos_strides}}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = conv_param.GetFlops(); + std::size_t num_btype = conv_param.GetByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(config.do_verification) + { + Tensor c_host(out_g_n_k_wos_desc); + + auto ref_conv = HostConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + c_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + + // TODO: implement elementwise operation for host + out_host.ForEach([&](auto&, auto idx) { + OutElementOp{}(out_host(idx), c_host(idx), bias(idx), residual(idx)); + }); + + out_device_buf.FromDevice(out_device.mData.data()); + +#ifdef BUILD_INT4_EXAMPLE + const Tensor out_device_converted(out_device); + + return ck::utils::check_err( + out_device_converted, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); +#else + return ck::utils::check_err( + out_device, out_host, "Error: incorrect results!", 1e-5f, 1e-4f); +#endif + } + + return true; +} + +bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) +{ + ExecutionConfig config; + ck::utils::conv::ConvParam conv_param = DefaultConvParam; + + if(!parse_cmd_args(argc, argv, config, conv_param)) + { + return false; + } + + switch(conv_param.num_dim_spatial_) + { + case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); + case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); + case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); + } + + return false; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000..b1a78dc99b --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,991 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] + +// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner +// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and +// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted +// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into +// TensorSpecialization::Default with NumDimG/M/N/K = 1 +// +// Detail- Packed tensor satisfies +// stride_0 = 1 +// stride_i = stride_{i - 1} * extent_{i - 1} +// So tensor +// [G0, G1, G2, M, N] +// transposed into tensor +// [G0, G2, G1, M, N] +// with strides +// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1] +// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some +// strides from input tensor extents so finer dimension information is lost. Merging dimensions is +// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. +// +// Might need to expose dimension order to the interface to fully support +// TensorSpecialization::Packed in a traditional sense of "packed" tensor +template +struct DeviceBatchedContractionMultipleD_Wmma_CShuffle + : public DeviceBatchedContractionMultipleD +{ + using DeviceOp = DeviceBatchedContractionMultipleD_Wmma_CShuffle; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + + // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && + a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto a_ms_ks_lengths = to_tuple( + a_gs_ms_ks_lengths_vec, Number{}, Number{}); + const auto a_ms_ks_strides = to_tuple( + a_gs_ms_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && + b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto b_ns_ks_lengths = to_tuple( + b_gs_ns_ks_lengths_vec, Number{}, Number{}); + const auto b_ns_ks_strides = to_tuple( + b_gs_ns_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_ms_ns_lengths = to_tuple( + e_gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto e_ms_ns_strides = to_tuple( + e_gs_ms_ns_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_G_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_gs_ms_ns_lengths = + to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto e_gs_ms_ns_strides = + to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}])); + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + else + { + // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + static auto MakeDsGridDescriptor_G_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + // Gridwise descriptor, mapping to whole given provblem. + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + using DsGridDesc_G_M_N = remove_cvref_t; + using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, + index_t batch_stride_B, + DsGridDesc_G_M_N ds_grid_desc_g_m_n, + EGridDesc_G_M_N e_grid_desc_g_m_n) + : batch_stride_A_(batch_stride_A), + batch_stride_B_(batch_stride_B), + ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n), + e_grid_desc_g_m_n_(e_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_A_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_B_; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = static_cast(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); + }); + + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0)); + } + + private: + index_t batch_stride_A_; + index_t batch_stride_B_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + }; + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{})); + using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{})); + + // GridwiseOp + using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& b_gs_ns_ks_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_strides, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{}, + b_grid_desc_n_k_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{}, + ds_grid_desc_g_m_n_{ + DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, + e_grid_desc_g_m_n_{ + DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock{}, + e_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_mz_stride_{}, + a_kz_stride_{}, + b_nz_stride_{}, + b_kz_stride_{}, + ds_nz_stride_{}, + e_nz_stride_{}, + a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]}, + b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]}, + compute_ptr_offset_of_batch_{ + a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_} + { + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + }); + + a_grid_desc_m_k_ = + DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + b_grid_desc_n_k_ = + DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + + ds_grid_desc_m_n_ = + DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); + + e_grid_desc_m_n_ = + DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_); + b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_); + + block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); + + ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + + // for sanity check of vector memory access + a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1]; + a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]; + b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1]; + b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]; + } + + // Pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseOp::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Tensor Descriptors + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // Idle + index_t M01_; + index_t N01_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + + index_t a_batch_stride_; + index_t b_batch_stride_; + + // Batch Offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + DeviceOp::AGridDesc_K0_M_K1, + DeviceOp::BGridDesc_K0_N_K1, + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + typename GridwiseOp::DefaultBlock2CTileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx1100") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(arg.a_mz_stride_ == 1 && + arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.a_kz_stride_ == 1 && + arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetLength(I3) % + CDEShuffleBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!((arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I3) % + CDEShuffleBlockTransferScalarPerVector_NPerBlock == + 0) || + CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ks_lengths, + b_gs_ns_ks_lengths, + ds_gs_ms_ns_lengths, + e_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_strides, + ds_gs_ms_ns_strides, + e_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ks_lengths, + b_gs_ns_ks_lengths, + ds_gs_ms_ns_lengths, + e_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_strides, + ds_gs_ms_ns_strides, + e_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceBatchedContractionMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerWMMA << ", " + << NPerWMMA << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000..66c4de7f05 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,654 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD +{ + using DeviceOp = DeviceGemmMultipleD_Wmma_CShuffle; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } +#ifdef ENABLE_COLMAJOR + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } +#endif + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto e_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& Ms, + const std::array& Ns, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(Ms[i], Ns[i], DsStride[i]); + }, + Number{}); + } + + // Gridwise descriptor, mapping to whole given provblem. + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseOp + using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock{}, + e_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideDs[i]); + }); + e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideE); + + block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); + + if(GridwiseOp::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_ctile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseOp::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Tensor Descriptors + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // Idle + index_t M01_; + index_t N01_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", " + << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl; + } +#endif + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t< + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + remove_reference_t, + true>; // Last Option is W/O + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t< + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + remove_reference_t, + false>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx1100") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerWMMA << ", " + << NPerWMMA << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000..e245902b6c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,850 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; +}; + +} // namespace + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// Assume: +// AK1 == BK1 +template +struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle + : public DeviceGroupedConvFwdMultipleD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr index_t KPerBlock = K0PerBlock * K1; + + static constexpr auto conv_to_gemm_transformer = + TransformConvFwdToGemm{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto + MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto + MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto + MakeEGridDescriptor_M_N(const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AGridDesc_M_K = remove_cvref_t({}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; + using BGridDesc_N_K = remove_cvref_t({}, {}))>; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t({}, {}))>; + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK1 = K1; + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK1 = K1; + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{})); + using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{})); + + // GridwiseOp + using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + K0PerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + index_t M01, + index_t N01, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, + b_g_k_c_xs_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, + e_g_n_k_wos_strides)}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + // using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + }); + + // D desc + ds_grid_desc_m_n_ = + DeviceOp::MakeDsGridDescriptor_M_N(ds_g_n_k_wos_lengths, ds_g_n_k_wos_strides); + + // populate desc for Ds/E + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseOp::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(get_device_name() == "gfx1100") + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + } + else + { + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp new file mode 100644 index 0000000000..2ce4d8feb3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,937 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + a_element_op, + b_element_op, + cde_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) + // printf("entry kernel launch"); + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + // printf("before compute_ptr_offset call"); + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + DsPointer p_ds_grid_grp; + + // printf("before allocate pointer d"); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + // printf("before entry"); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_mupltipe_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__)) + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + + GridwiseOp::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx1100__)) +} + +template < // DataType Family + typename ADataType, + typename BDataType, + typename AccDataType, + typename CShuffleDataType, + typename DsDataType, + typename EDataType, + // InMemory Data Descriptor + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename DsGridDesc_M_N, + typename EGridDesc_M_N, + // ElementwiseOp Family + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CDEElementwiseOperation, + InMemoryDataOperationEnum EGlobalMemoryDataOperation, + // Tiling Family + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerWmma, + index_t NPerWmma, + index_t K1Value, + index_t MRepeat, + index_t NRepeat, + // ThreadCluster Family + index_t BlockSize, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMRepeatPerShuffle, + index_t CShuffleNRepeatPerShuffle, + typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, + index_t NumGemmKPrefetchStage = 1, + LoopScheduler LoopSched = make_default_loop_scheduler(), + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0perblock_mperblock_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0perblock_nperblock_k1; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0perblock_mperblock_k1 = + GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0perblock_nperblock_k1 = + GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const EGridDesc_M_N& e_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // printf("safe entry"); + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + constexpr auto max_lds_align = K1; + constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +/* typename SrcElementwiseOperation, */ AElementwiseOperation, +/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), +/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0perblock_mperblock_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0perblock_nperblock_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0perblock_nperblock_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + +/*******************************************************************************/ + // GEMM + constexpr auto WmmaK = 16; + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); + + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0perblock_mperblock_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0perblock_nperblock_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); +/*******************************************************************************/ + //printf("safe 1"); + // write out to C, implement shuffle + { + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor buffers + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // shuffle: blockwise copy C from LDS to global + auto cde_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, // ThreadGroup + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, // ElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // bool ThreadTransferSrcResetCoordinateAfterRun, + Sequence> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_cde_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_shuffle_block_copy_lds_to_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); + + // move on E + cde_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck From 6a6163a3d1d3e9336092c1d3ae152c2da41fc094 Mon Sep 17 00:00:00 2001 From: rocking5566 Date: Thu, 16 Feb 2023 01:59:35 +0800 Subject: [PATCH 15/22] Improve normalization (#580) * Sync the order of type string with template parameter * Add more instances * Check the vector size and remove redundant var * Extract var to static, prepare to separate sweep once kernel * Separate sweeponce flow and optimize the flow * 1. Rename AccDatatype in normalization to computeData 2. Rename AccElementwiseOperation to YElementwiseOperation in normalization * Remove useless code * Update naive variance kernel * Refine string * Fix typo * Support naive variance for device_normalization * Check the blocksize * Share the VGPR of x and y * Share the VGPR of gamma and beta * Add more instances * Support fp16 sqrt for experiment * Add CHANGELOG * Fix typo * clang-format --- CHANGELOG.md | 2 +- client_example/05_layernorm/layernorm2d.cpp | 14 +- example/27_layernorm/layernorm_blockwise.cpp | 16 +- .../42_groupnorm/groupnorm_sigmoid_fp16.cpp | 14 +- .../gpu/device/device_normalization.hpp | 14 +- .../device/impl/device_normalization_impl.hpp | 152 ++---- .../gridwise_normalization_naive_variance.hpp | 461 ++++++++++++------ .../grid/gridwise_normalization_selector.hpp | 195 ++++++++ ...ridwise_normalization_welford_variance.hpp | 287 +++++++---- include/ck/utility/math_v2.hpp | 10 + .../device_normalization_f16_instance.cpp | 27 +- .../device_normalization_f32_instance.cpp | 27 +- .../profiler/profile_layernorm_impl.hpp | 10 +- test/normalization/test_groupnorm_fp16.cpp | 14 +- test/normalization/test_groupnorm_fp32.cpp | 14 +- test/normalization/test_layernorm2d_fp16.cpp | 14 +- test/normalization/test_layernorm2d_fp32.cpp | 14 +- 17 files changed, 826 insertions(+), 459 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index 2c3215ae44..23e7fb6274 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,7 @@ Full documentation for Composable Kernel is not yet available. - Fixed grouped ConvBwdWeight test case failure (#524). ### Optimizations -- Optimized ... +- Improve proformance of normalization kernel ### Added - Added user tutorial (#563). diff --git a/client_example/05_layernorm/layernorm2d.cpp b/client_example/05_layernorm/layernorm2d.cpp index adb41171e1..856a4cc219 100644 --- a/client_example/05_layernorm/layernorm2d.cpp +++ b/client_example/05_layernorm/layernorm2d.cpp @@ -12,12 +12,12 @@ #include "ck/library/tensor_operation_instance/gpu/normalization.hpp" -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using AccDataType = float; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using ComputeDataType = float; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; constexpr int Rank = 2; constexpr int NumReduceDim = 1; @@ -54,7 +54,7 @@ int main(int argc, char* argv[]) using DeviceOp = ck::tensor_operation::device::DeviceNormalization; diff --git a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp index e62001d669..35c7c054e0 100644 --- a/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp +++ b/example/42_groupnorm/groupnorm_sigmoid_fp16.cpp @@ -23,11 +23,11 @@ constexpr int Rank = 5; constexpr int NumReduceDim = 3; -using XDataType = ck::half_t; -using GammaDataType = ck::half_t; -using BetaDataType = ck::half_t; -using YDataType = ck::half_t; -using AccDataType = float; +using XDataType = ck::half_t; +using GammaDataType = ck::half_t; +using BetaDataType = ck::half_t; +using YDataType = ck::half_t; +using ComputeDataType = float; struct YElementOp { @@ -50,7 +50,7 @@ using DeviceInstance = ck::tensor_operation::device::DeviceNormalizationImpl; ReferenceInstance ref; diff --git a/include/ck/tensor_operation/gpu/device/device_normalization.hpp b/include/ck/tensor_operation/gpu/device/device_normalization.hpp index ec17ec3d18..03601ce831 100644 --- a/include/ck/tensor_operation/gpu/device/device_normalization.hpp +++ b/include/ck/tensor_operation/gpu/device/device_normalization.hpp @@ -14,9 +14,9 @@ namespace device { template struct DeviceNormalization : public BaseOperator @@ -35,7 +35,7 @@ struct DeviceNormalization : public BaseOperator void* p_y, void* p_savedMean, void* p_savedInvVar, - AccElementwiseOperation acc_elementwise_op) = 0; + YElementwiseOperation y_elementwise_op) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; @@ -43,17 +43,17 @@ struct DeviceNormalization : public BaseOperator template using DeviceNormalizationPtr = std::unique_ptr>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp index 8cc223a886..bb62332d1a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp @@ -10,46 +10,11 @@ #include "ck/tensor_operation/gpu/device/device_normalization.hpp" #include "ck/tensor_operation/gpu/device/device_reduce.hpp" #include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" -namespace ck { -template -__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, - const GridDesc_M_K gamma_grid_desc_m_k, - const GridDesc_M_K beta_grid_desc_m_k, - const GridDesc_M_K y_grid_desc_m_k, - index_t num_k_block_tile_iteration, - AccDataType epsilon, - const XDataType* const __restrict__ p_x_global, - const GammaDataType* const __restrict__ p_gamma_global, - const BetaDataType* const __restrict__ p_beta_global, - YDataType* const __restrict__ p_y_global, - const AccElementwiseOperation acc_elementwise_op) -{ - GridwiseReduction::Run(x_grid_desc_m_k, - gamma_grid_desc_m_k, - beta_grid_desc_m_k, - y_grid_desc_m_k, - num_k_block_tile_iteration, - epsilon, - p_x_global, - p_gamma_global, - p_beta_global, - p_y_global, - acc_elementwise_op); -}; -} // namespace ck - namespace ck { namespace tensor_operation { namespace device { @@ -58,9 +23,9 @@ namespace device { template + index_t YDstVectorSize, + bool UseWelford = true> struct DeviceNormalizationImpl : public DeviceNormalization { + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); static_assert( ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), @@ -167,51 +134,6 @@ struct DeviceNormalizationImpl : public DeviceNormalization; - using GridwiseNormalizationSweepOnce = - GridwiseNormalizationWelfordVariance_mk_to_mk; - struct Argument : public BaseArgument { Argument(const std::vector lengths, @@ -220,7 +142,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization betaStrides, const std::vector yStrides, const std::vector reduceDims, - AccElementwiseOperation acc_elementwise_op, + YElementwiseOperation y_elementwise_op, double epsilon, const XDataType* p_x, const GammaDataType* p_gamma, @@ -230,9 +152,9 @@ struct DeviceNormalizationImpl : public DeviceNormalization(epsilon); + epsilon_ = static_cast(epsilon); Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); @@ -265,7 +187,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization{}) <= KThreadClusterSize * KThreadSliceSize; } - AccDataType epsilon_; + ComputeDataType epsilon_; const XDataType* p_x_; const GammaDataType* p_gamma_; @@ -278,7 +200,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization betaStrides_; std::vector yStrides_; - AccElementwiseOperation acc_elementwise_op_; + YElementwiseOperation y_elementwise_op_; int blkGroupSize_; int numBlockTileIteration_; @@ -295,23 +217,27 @@ struct DeviceNormalizationImpl : public DeviceNormalization - : kernel_normalization; + auto kernel_main = NormalizationKernelSelector(arg.isSweeponce_); float avg_time = 0; avg_time += launch_and_time_kernel(stream_config, @@ -329,7 +255,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization(p_x), static_cast(p_gamma), @@ -462,8 +388,8 @@ struct DeviceNormalizationImpl : public DeviceNormalization"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp index 89efea4d6c..792ffabcb9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp @@ -4,9 +4,8 @@ #pragma once #include "ck/utility/data_type.hpp" -#include "ck/utility/reduction_common.hpp" + #include "ck/utility/reduction_operator.hpp" -#include "ck/utility/reduction_functions_accumulate.hpp" #include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" #include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" @@ -19,8 +18,8 @@ template ; @@ -59,19 +62,23 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}))); + make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); - using BlockwiseSumReduce = PartitionedBlockwiseReduction; - using ThreadwiseSumReduce = ThreadwiseReduction{}; static constexpr auto I2 = Number<2>{}; - static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; - static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k, const GridDesc_M_K& gamma_grid_desc_m_k, const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, index_t num_k_block_tile_iteration, - AccDataType epsilon, + ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, - const AccElementwiseOperation acc_elementwise_op) + const YElementwiseOperation y_elementwise_op) { - if constexpr(SweepOnce) - { - num_k_block_tile_iteration = 1; - } - // LDS - __shared__ AccDataType p_reduce_work_buffer[BlockSize]; - - auto y_global_val_buf = make_dynamic_buffer( - p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; auto reduce_work_buf = make_dynamic_buffer(p_reduce_work_buffer, BlockSize); - StaticBuffer - x_thread_buf; + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); - StaticBuffer - gamma_thread_buf; + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); - StaticBuffer& beta_thread_buf = gamma_thread_buf; + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); - StaticBuffer - y_thread_buf; + auto& beta_thread_buf = gamma_thread_buf; - StaticBuffer& x_square_thread_buf = y_thread_buf; + auto y_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); - StaticBuffer mean_thread_buf; - StaticBuffer + auto& x_square_thread_buf = y_thread_buf; + + StaticBuffer + mean_thread_buf; + StaticBuffer mean_square_thread_buf; - StaticBuffer& var_thread_buf = - mean_square_thread_buf; - - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - mean_thread_buf(I) = reduce::Add::template GetIdentityValue(); - mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue(); - }); + StaticBuffer& + var_thread_buf = mean_square_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -149,12 +162,8 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_k_cluster_id = thread_cluster_idx[I1]; - using ThreadBufferLengths_M_K = Sequence; - constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - x_square_thread_buf(Number{}) = - x_thread_buf(Number{}) * x_thread_buf(Number{}); - }); - }); - - ThreadwiseSumReduce::Reduce(x_thread_buf, mean_thread_buf); - ThreadwiseSumReduce::Reduce(x_square_thread_buf, mean_square_thread_buf); - - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - - ++reducedTiles; - } while(reducedTiles < num_k_block_tile_iteration); - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if constexpr(I > 0) - block_sync_lds(); - - BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); - mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; - - block_sync_lds(); - - BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); - mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; - - // var(x) = E[x^2] - E[x]^2 - var_thread_buf(I) = - mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + mean_thread_buf(I) = reduce::Add::template GetIdentityValue(); + mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue(); }); - // y = (x - E[x]) / sqrt(var[x] + epsilon) - auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; - - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); - - reducedTiles = 0; - do + // Separate sweep once and sweep twice pipeline + if constexpr(SweepOnce) { - if constexpr(!SweepOnce) - { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k, make_tuple(I0, I0), - x_thread_buf); + x_thread_buf(i)); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); + }); + }); + + ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); + ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf); + + if constexpr(i != ThreadBufferNumber - 1) + { + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + } + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); + mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; + + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); + mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; + + // var(x) = E[x^2] - E[x]^2 + var_thread_buf(I) = + mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma & beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + } // end of sweep once + else + { + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); + }); + }); + + ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); + ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf); + }); } - threadwise_gamma_load.Run(gamma_grid_desc_m_k, - gamma_global_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - gamma_thread_buf); + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); + mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; - // normalize - y_thread_buf(Number{}) = - (x_thread_buf(Number{}) - mean_thread_buf(iM)) / - sqrt(var_thread_buf(iM) + epsilon); + block_sync_lds(); - // gamma - y_thread_buf(Number{}) = - y_thread_buf(Number{}) * gamma_thread_buf(Number{}); - }); + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); + mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; + + // var(x) = E[x^2] - E[x]^2 + var_thread_buf(I) = + mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); }); - threadwise_beta_load.Run(beta_grid_desc_m_k, - beta_global_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - beta_thread_buf); - - static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { - constexpr auto offset_m_k = - thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); - - // beta - y_thread_buf(Number{}) = - y_thread_buf(Number{}) + beta_thread_buf(Number{}); - }); - }); - - threadwise_y_store.Run(thread_buffer_desc_m_k, - make_tuple(I0, I0), - y_thread_buf, - y_grid_desc_m_k, - y_global_val_buf); + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); - ++reducedTiles; - } while(reducedTiles < num_k_block_tile_iteration); + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + } + } // end of sweep twice } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp new file mode 100644 index 0000000000..37795fa569 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_selector.hpp @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/grid/gridwise_normalization_naive_variance.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp" + +namespace ck { +template +__global__ void kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K beta_grid_desc_m_k, + const GridDesc_M_K y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + ComputeDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const YElementwiseOperation y_elementwise_op) +{ + GridwiseReduction::Run(x_grid_desc_m_k, + gamma_grid_desc_m_k, + beta_grid_desc_m_k, + y_grid_desc_m_k, + num_k_block_tile_iteration, + epsilon, + p_x_global, + p_gamma_global, + p_beta_global, + p_y_global, + y_elementwise_op); +}; + +template +auto NormalizationKernelSelector(bool isSweepOnce) +{ + using GridwiseNormalizationGenericNaive = + GridwiseNormalizationNaiveVariance_mk_to_mk; + using GridwiseNormalizationSweepOnceNaive = + GridwiseNormalizationNaiveVariance_mk_to_mk; + using GridwiseNormalizationGenericWelford = + GridwiseNormalizationWelfordVariance_mk_to_mk; + using GridwiseNormalizationSweepOnceWelford = + GridwiseNormalizationWelfordVariance_mk_to_mk; + + if constexpr(UseWelford) + { + return isSweepOnce ? kernel_normalization + : kernel_normalization; + } + else + { + return isSweepOnce ? kernel_normalization + : kernel_normalization; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp index 70a8c020dd..3a7ae459e5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_normalization_welford_variance.hpp @@ -16,8 +16,8 @@ template ; @@ -56,15 +60,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk static constexpr auto thread_cluster_desc = make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}))); using ThreadReduceDstDesc_M = decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); using ThreadwiseWelford = - ThreadwiseWelford; + ThreadwiseWelford; - using BlockwiseWelford = BlockwiseWelford; @@ -77,10 +85,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; - static constexpr auto XThreadBufferNumber = Number{}; - static constexpr auto GammaThreadBufferNumber = Number{}; - static constexpr auto BetaThreadBufferNumber = Number{}; - static constexpr auto YThreadBufferNumber = Number{}; + static constexpr auto ThreadBufferNumber = Number{}; __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, int thread_k_cluster_id) @@ -93,7 +98,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk if(kPerBlockTail > 0) { - static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { int thread_max_len = (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; int delta = thread_max_len - kPerBlockTail; @@ -110,59 +115,41 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const GridDesc_M_K& beta_grid_desc_m_k, const GridDesc_M_K& y_grid_desc_m_k, index_t num_k_block_tile_iteration, - AccDataType epsilon, + ComputeDataType epsilon, const XDataType* const __restrict__ p_x_global, const GammaDataType* const __restrict__ p_gamma_global, const BetaDataType* const __restrict__ p_beta_global, YDataType* const __restrict__ p_y_global, - const AccElementwiseOperation acc_elementwise_op) + const YElementwiseOperation y_elementwise_op) { - if constexpr(SweepOnce) - { - num_k_block_tile_iteration = 1; - } - auto y_global_val_buf = make_dynamic_buffer( p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); auto x_thread_buf = generate_tuple( [&](auto) { return StaticBuffer{}; }, - Number{}); + Number{}); auto gamma_thread_buf = generate_tuple( [&](auto) { return StaticBuffer{}; }, - Number{}); + Number{}); - auto beta_thread_buf = generate_tuple( - [&](auto) { - return StaticBuffer{}; - }, - Number{}); + auto& beta_thread_buf = gamma_thread_buf; + auto& y_thread_buf = x_thread_buf; - auto y_thread_buf = generate_tuple( - [&](auto) { - return StaticBuffer{}; - }, - Number{}); - - StaticBuffer mean_thread_buf; - StaticBuffer var_thread_buf; + StaticBuffer + mean_thread_buf; + StaticBuffer + var_thread_buf; const index_t thread_local_id = get_thread_local_1d_id(); const index_t block_global_id = get_block_1d_id(); @@ -173,12 +160,8 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_k_cluster_id = thread_cluster_idx[I1]; - using ThreadBufferLengths_M_K = Sequence; - constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); - auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2{}([&](auto I) { - mean_thread_buf(I) = type_convert(0.0f); - var_thread_buf(I) = type_convert(0.0f); + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); }); - for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + // Separate sweep once and sweep twice pipeline + if constexpr(SweepOnce) { - static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf, thread_buffer_desc_m_k, make_tuple(I0, I0), x_thread_buf(i)); - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); - }); - } - static_for<0, MThreadSliceSize, 1>{}([&](auto I) { - if constexpr(I > 0) - block_sync_lds(); - - int count = threadwise_welford.cur_count_; - BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); - }); - - auto thread_copy_tail_m_k = - (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k; - - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); - - for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) - { - if constexpr(!SweepOnce) - { - static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { - threadwise_x_load.Run(x_grid_desc_m_k, - x_global_val_buf, - thread_buffer_desc_m_k, - make_tuple(I0, I0), - x_thread_buf(i)); - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); - }); - } - - static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) { threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf, thread_buffer_desc_m_k, make_tuple(I0, I0), gamma_thread_buf(i)); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, - thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + + if constexpr(i != ThreadBufferNumber - 1) + { + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + } + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); @@ -330,7 +293,7 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * divisor; - // gamma + // gamma & beta y_thread_buf(iK0)(Number{}) = y_thread_buf(iK0)(Number{}) * gamma_thread_buf(iK0)(Number{}); @@ -338,18 +301,20 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk }); }); - static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_beta_load.Run(beta_grid_desc_m_k, beta_global_val_buf, thread_buffer_desc_m_k, make_tuple(I0, I0), beta_thread_buf(i)); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, - thread_copy_fwd_step_m_k); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); }); static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { - static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { constexpr auto offset_m_k = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); @@ -362,22 +327,134 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk }); }); - static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { threadwise_y_store.Run(thread_buffer_desc_m_k, make_tuple(I0, I0), y_thread_buf(i), y_grid_desc_m_k, y_global_val_buf); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + } // end of sweep once + else + { + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); }); - threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); - threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, - 2 * thread_copy_bwd_step_m_k); - threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, - 2 * thread_copy_bwd_step_m_k); - threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); - } + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + } + } // end of sweep twice } }; diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index 4aba0b1192..4febace0b8 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -83,6 +83,11 @@ static inline __host__ bool isnan(int4_t x) }; #endif +static inline __host__ half_t sqrt(half_t x) +{ + return static_cast(std::sqrt(static_cast(x))); +}; + static inline __host__ float sqrt(float x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); }; @@ -158,6 +163,11 @@ static inline __device__ bool isnan(half_t x) return (xx & 0x7FFF) > 0x7C00; }; +static inline __device__ half_t sqrt(half_t x) +{ + return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); +}; + static inline __device__ float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; static inline __device__ double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp index 8994d9dcb6..beeaa3aa22 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f16_instance.cpp @@ -21,20 +21,25 @@ template // clang-format off using device_normalization_f16_instances = std::tuple < - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, DeviceNormalizationImpl, DeviceNormalizationImpl, + DeviceNormalizationImpl, DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp index 4a7e1fd0b9..4d236fb633 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/normalization/device_normalization_f32_instance.cpp @@ -19,17 +19,26 @@ using Pass = ck::tensor_operation::element_wise::PassThrough; template using device_layernorm_f32_instances = std::tuple< // clang-format off - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, // fallback kernel - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, - DeviceNormalizationImpl, + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize> + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, // irregular size + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, DeviceNormalizationImpl, DeviceNormalizationImpl, - DeviceNormalizationImpl + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl, + DeviceNormalizationImpl // clang-format on >; diff --git a/profiler/include/profiler/profile_layernorm_impl.hpp b/profiler/include/profiler/profile_layernorm_impl.hpp index eb21d4a586..7dd90d0797 100644 --- a/profiler/include/profiler/profile_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_layernorm_impl.hpp @@ -19,7 +19,7 @@ namespace profiler { template bool profile_layernorm_impl(int do_verification, @@ -86,7 +86,7 @@ bool profile_layernorm_impl(int do_verification, using DeviceOp = ck::tensor_operation::device::DeviceNormalization; @@ -181,8 +181,8 @@ bool profile_layernorm_impl(int do_verification, { y_dev.FromDevice(y.mData.data()); - bool pass = ck::utils::check_err( - y.mData, host_y.mData, "Error: Incorrect results d1", 1e-3, 1e-3); + bool pass = + ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); if(do_log) { diff --git a/test/normalization/test_groupnorm_fp16.cpp b/test/normalization/test_groupnorm_fp16.cpp index 636e522dce..60d3b13959 100644 --- a/test/normalization/test_groupnorm_fp16.cpp +++ b/test/normalization/test_groupnorm_fp16.cpp @@ -12,11 +12,11 @@ template class TestGroupnorm : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using AccDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; void Run() { @@ -36,7 +36,7 @@ class TestGroupnorm : public ::testing::Test ck::profiler::profile_groupnorm_impl(true, 2, false, false, length); EXPECT_TRUE(success); } @@ -44,7 +44,7 @@ class TestGroupnorm : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> std::tuple>; TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); diff --git a/test/normalization/test_groupnorm_fp32.cpp b/test/normalization/test_groupnorm_fp32.cpp index ef492664bf..3542f73a62 100644 --- a/test/normalization/test_groupnorm_fp32.cpp +++ b/test/normalization/test_groupnorm_fp32.cpp @@ -12,11 +12,11 @@ template class TestGroupnorm : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using AccDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; void Run() { @@ -34,7 +34,7 @@ class TestGroupnorm : public ::testing::Test ck::profiler::profile_groupnorm_impl(true, 2, false, false, length); EXPECT_TRUE(success); } @@ -42,7 +42,7 @@ class TestGroupnorm : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> std::tuple>; TYPED_TEST_SUITE(TestGroupnorm, KernelTypes); diff --git a/test/normalization/test_layernorm2d_fp16.cpp b/test/normalization/test_layernorm2d_fp16.cpp index eeb8ec150a..d627cbe7f1 100644 --- a/test/normalization/test_layernorm2d_fp16.cpp +++ b/test/normalization/test_layernorm2d_fp16.cpp @@ -12,11 +12,11 @@ template class TestLayernorm2d : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using AccDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; void Run() { @@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test bool success = ck::profiler::profile_layernorm_impl(true, 2, false, false, length); EXPECT_TRUE(success); @@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> std::tuple>; TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); diff --git a/test/normalization/test_layernorm2d_fp32.cpp b/test/normalization/test_layernorm2d_fp32.cpp index f555b42592..de4133aa83 100644 --- a/test/normalization/test_layernorm2d_fp32.cpp +++ b/test/normalization/test_layernorm2d_fp32.cpp @@ -12,11 +12,11 @@ template class TestLayernorm2d : public ::testing::Test { protected: - using XDataType = std::tuple_element_t<0, Tuple>; - using GammaDataType = std::tuple_element_t<1, Tuple>; - using BetaDataType = std::tuple_element_t<2, Tuple>; - using AccDataType = std::tuple_element_t<3, Tuple>; - using YDataType = std::tuple_element_t<4, Tuple>; + using XDataType = std::tuple_element_t<0, Tuple>; + using GammaDataType = std::tuple_element_t<1, Tuple>; + using BetaDataType = std::tuple_element_t<2, Tuple>; + using ComputeDataType = std::tuple_element_t<3, Tuple>; + using YDataType = std::tuple_element_t<4, Tuple>; void Run() { @@ -29,7 +29,7 @@ class TestLayernorm2d : public ::testing::Test bool success = ck::profiler::profile_layernorm_impl(true, 2, false, false, length); EXPECT_TRUE(success); @@ -38,7 +38,7 @@ class TestLayernorm2d : public ::testing::Test }; using KernelTypes = ::testing::Types< - // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType> + // XDataType, GammaDataType, BetaDataType, ComputeDataType, YDataType> std::tuple>; TYPED_TEST_SUITE(TestLayernorm2d, KernelTypes); From 24c9ee1d22e02579d2e5db255722debb020e133b Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 15 Feb 2023 12:00:58 -0600 Subject: [PATCH 16/22] Add contraction_fp64 example (#570) * add contraction_bilinear * add contraction_scale_xdl_fp64 * reduce tile size to avoid register spill --------- Co-authored-by: root --- example/26_contraction/CMakeLists.txt | 3 + .../contraction_bilinear_xdl_fp64.cpp | 427 ++++++++++++++++++ .../contraction_scale_xdl_fp64.cpp | 409 +++++++++++++++++ .../element/binary_element_wise_operation.hpp | 7 + .../element/unary_element_wise_operation.hpp | 6 + script/cmake-ck-dev.sh | 4 +- 6 files changed, 854 insertions(+), 2 deletions(-) create mode 100644 example/26_contraction/contraction_bilinear_xdl_fp64.cpp create mode 100644 example/26_contraction/contraction_scale_xdl_fp64.cpp diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 87f4750e3b..c58751f0dd 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -1,2 +1,5 @@ add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) + +add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) +add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp new file mode 100644 index 0000000000..9a000377bb --- /dev/null +++ b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp @@ -0,0 +1,427 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F64; +using CShuffleDataType = F64; +using DDataType = F64; +using DsDataType = ck::Tuple; +using EDataType = F64; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceKNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceMKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceMNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKNN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + e_ms_ns_{e_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& e_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M2_N2_K2::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; + const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + for(int k1 = 0; k1 < K1; ++k1) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); + + v_acc += v_a * v_b; + } + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_ms_ns_(m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_ms_ns_.mDesc.GetLengths()[0], + arg.e_ms_ns_.mDesc.GetLengths()[1], + arg.e_ms_ns_.mDesc.GetLengths()[2], + arg.e_ms_ns_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M2_N2_K2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_scale_xdl_fp64.cpp b/example/26_contraction/contraction_scale_xdl_fp64.cpp new file mode 100644 index 0000000000..38ed60266d --- /dev/null +++ b/example/26_contraction/contraction_scale_xdl_fp64.cpp @@ -0,0 +1,409 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" + +template +using S = ck::Sequence; + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F64; +using BDataType = F64; +using AccDataType = F64; +using CShuffleDataType = F64; +using DsDataType = ck::Tuple<>; +using EDataType = F64; + +static constexpr ck::index_t NumDimM = 2; +static constexpr ck::index_t NumDimN = 2; +static constexpr ck::index_t NumDimK = 2; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceOpInstanceKKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceMKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceMNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; +// clang-format on + +using DeviceOpInstance = DeviceOpInstanceKKN; + +// hardcoded for NumDimM == NumDimN == NumDimK == 2 +template = false> +struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator +{ + // Argument + struct Argument : public ck::tensor_operation::device::BaseArgument + { + Argument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_ms_ks_{a_ms_ks}, + b_ns_ks_{b_ns_ks}, + e_ms_ns_{e_ms_ns}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_ms_ks_; + const Tensor& b_ns_ks_; + Tensor& e_ms_ns_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public ck::tensor_operation::device::BaseInvoker + { + using Argument = ReferenceContraction_M2_N2_K2::Argument; + + float Run(const Argument& arg) + { + auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { + const int K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; + const int K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k0 = 0; k0 < K0; ++k0) + { + for(int k1 = 0; k1 < K1; ++k1) + { + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); + + v_acc += v_a * v_b; + } + } + + AccDataType v_c; + + arg.cde_element_op_(v_c, v_acc); + + arg.e_ms_ns_(m0, m1, n0, n1) = v_c; + }; + + make_ParallelTensorFunctor(f_ms_ns, + arg.e_ms_ns_.mDesc.GetLengths()[0], + arg.e_ms_ns_.mDesc.GetLengths()[1], + arg.e_ms_ns_.mDesc.GetLengths()[2], + arg.e_ms_ns_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const ck::tensor_operation::device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const ck::tensor_operation::device::BaseArgument*) override + { + return true; + } + + static auto MakeArgument(const Tensor& a_ms_ks, + const Tensor& b_ns_ks, + Tensor& e_ms_ns, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_ms_ks, b_ns_ks, e_ms_ns, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceContraction_M2_N2_K2" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = ReferenceContraction_M2_N2_K2; + + auto ref_gemm = ReferenceOpInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index d0ba07e5a9..69fa75c3fd 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -150,6 +150,13 @@ struct Bilinear template __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = alpha_ * x0 + beta_ * x1; + }; + template <> __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index fbdfe9262f..2167a79e01 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -95,6 +95,12 @@ struct Scale y = scale_ * x; }; + template <> + __host__ __device__ void operator()(double& y, const double& x) const + { + y = scale_ * x; + }; + float scale_; }; diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 2e605ce8de..3e530478b0 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -10,8 +10,8 @@ cmake -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_BUILD_TYPE=Release \ --D BUILD_DEV=ON \ --D GPU_TARGETS="gfx908;gfx90a" \ +-D BUILD_DEV=OFF \ +-D GPU_TARGETS="gfx90a" \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D USE_BITINT_EXTENSION_INT4=OFF \ ${MY_PROJECT_SOURCE} From 19490ac4f71735495bcdd080a2531d679c3573e9 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 15 Feb 2023 10:07:21 -0800 Subject: [PATCH 17/22] Clean up kernel launch output (#569) * clean up output from kernel_launch * set RUN_WARMUP to 0 by default * split the warm-up into a separate issue --------- Co-authored-by: zjing14 --- include/ck/host_utility/kernel_launch.hpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/include/ck/host_utility/kernel_launch.hpp b/include/ck/host_utility/kernel_launch.hpp index ed6e2f0ba1..24f2121674 100644 --- a/include/ck/host_utility/kernel_launch.hpp +++ b/include/ck/host_utility/kernel_launch.hpp @@ -20,6 +20,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config, #if CK_TIME_KERNEL if(stream_config.time_kernel_) { +#if DEBUG_LOG printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", __func__, grid_dim.x, @@ -29,15 +30,15 @@ float launch_and_time_kernel(const StreamConfig& stream_config, block_dim.y, block_dim.z); - const int nrepeat = 10; - printf("Warm up 1 time\n"); - +#endif // warm up kernel<<>>(args...); + const int nrepeat = 10; +#if DEBUG_LOG printf("Start running %d times...\n", nrepeat); - +#endif hipEvent_t start, stop; hip_check_error(hipEventCreate(&start)); From cb3fac4d2a4e8bcac0e15118d7afd4af93301132 Mon Sep 17 00:00:00 2001 From: pmaybank <113125070+pmaybank@users.noreply.github.com> Date: Wed, 15 Feb 2023 23:17:46 +0000 Subject: [PATCH 18/22] Sphinx doc (#581) * New docs directory with minimal config * Based on docs directory of rocBLAS * Config for running Doxygen then Sphinx to generate HTML * Add minimal content - intro to doc * Add some boilerplate sections to doc * content still needs to be done, * e.g., need to generate API documentation using Doxygen * need to write contributor guide * Start Softmax section of Support Primitives doc * Written as a test bed for typesetting math content * Need to decide how much detail to go into * add doc directories to git ignore file. * Minor edits - new line at EOF, change year in copyright notices * Port Markdown files to ReStructuredText * Copy Markdown files from pre-existing doc directory to docs directory * Convert to reStructured Text (rst) - section headings, links, tables have a different syntax in rst * New rst files added to index - can generate HTML with same style as HTML generated from rst files in previous commits * Intention is to make all the content in doc redundant and use rst throughout rather than mix of md and rst * Extend Softmax section of Primitives Guide * rename l to z * add material on applying softmax row-wise to matrix * define macro for diag operator (represents diagonal matrix) --------- Co-authored-by: zjing14 --- .gitignore | 4 + docs/Doxyfile | 2453 ++++++++++++++++++++ docs/run_doc.sh | 15 + docs/run_doxygen.sh | 10 + docs/source/API_Reference_Guide.rst | 23 + docs/source/Contributors_Guide.rst | 8 + docs/source/Disclaimer.rst | 13 + docs/source/Linux_Install_Guide.rst | 15 + docs/source/Makefile | 20 + docs/source/Supported_Primitives_Guide.rst | 75 + docs/source/conf.py | 216 ++ docs/source/dockerhub.rst | 96 + docs/source/index.rst | 16 + docs/source/rocm_logo.png | Bin 0 -> 355437 bytes docs/source/tutorial_hello_world.rst | 174 ++ 15 files changed, 3138 insertions(+) create mode 100644 docs/Doxyfile create mode 100755 docs/run_doc.sh create mode 100755 docs/run_doxygen.sh create mode 100644 docs/source/API_Reference_Guide.rst create mode 100644 docs/source/Contributors_Guide.rst create mode 100644 docs/source/Disclaimer.rst create mode 100644 docs/source/Linux_Install_Guide.rst create mode 100644 docs/source/Makefile create mode 100644 docs/source/Supported_Primitives_Guide.rst create mode 100644 docs/source/conf.py create mode 100644 docs/source/dockerhub.rst create mode 100644 docs/source/index.rst create mode 100644 docs/source/rocm_logo.png create mode 100644 docs/source/tutorial_hello_world.rst diff --git a/.gitignore b/.gitignore index 71059ec4d9..5667695bb5 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,7 @@ build* # GDB temporary files .gdb_history install.dir* + +# directories containing generated documentation +docs/source/_build/ +docs/docBin/ diff --git a/docs/Doxyfile b/docs/Doxyfile new file mode 100644 index 0000000000..958b3b6f44 --- /dev/null +++ b/docs/Doxyfile @@ -0,0 +1,2453 @@ +# Doxyfile 1.8.10 + +# This file describes the settings to be used by the documentation system +# doxygen (www.doxygen.org) for a project. +# +# All text after a double hash (##) is considered a comment and is placed in +# front of the TAG it is preceding. +# +# All text after a single hash (#) is considered a comment and will be ignored. +# The format is: +# TAG = value [value, ...] +# For lists, items can also be appended using: +# TAG += value [value, ...] +# Values that contain spaces should be placed between quotes (\" \"). + +#--------------------------------------------------------------------------- +# Project related configuration options +#--------------------------------------------------------------------------- + +# This tag specifies the encoding used for all characters in the config file +# that follow. The default is UTF-8 which is also the encoding used for all text +# before the first occurrence of this tag. Doxygen uses libiconv (or the iconv +# built into libc) for the transcoding. See http://www.gnu.org/software/libiconv +# for the list of possible encodings. +# The default value is: UTF-8. + +DOXYFILE_ENCODING = UTF-8 + +# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by +# double-quotes, unless you are using Doxywizard) that should identify the +# project for which the documentation is generated. This name is used in the +# title of most generated pages and in a few other places. +# The default value is: My Project. + +PROJECT_NAME = "ck" + +# The PROJECT_NUMBER tag can be used to enter a project or revision number. This +# could be handy for archiving the generated documentation or if some version +# control system is used. + +PROJECT_NUMBER = v3.0.1.0 + +# Using the PROJECT_BRIEF tag one can provide an optional one line description +# for a project that appears at the top of each page and should give viewer a +# quick idea about the purpose of the project. Keep the description short. + +PROJECT_BRIEF = "prototype interfaces compatible with ROCm platform and HiP" + +# With the PROJECT_LOGO tag one can specify a logo or an icon that is included +# in the documentation. The maximum height of the logo should not exceed 55 +# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy +# the logo to the output directory. + +PROJECT_LOGO = ./rocm.jpg + +# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path +# into which the generated documentation will be written. If a relative path is +# entered, it will be relative to the location where doxygen was started. If +# left blank the current directory will be used. + +OUTPUT_DIRECTORY = docBin + +# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub- +# directories (in 2 levels) under the output directory of each output format and +# will distribute the generated files over these directories. Enabling this +# option can be useful when feeding doxygen a huge amount of source files, where +# putting all generated files in the same directory would otherwise causes +# performance problems for the file system. +# The default value is: NO. + +CREATE_SUBDIRS = NO + +# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII +# characters to appear in the names of generated files. If set to NO, non-ASCII +# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode +# U+3044. +# The default value is: NO. + +ALLOW_UNICODE_NAMES = NO + +# The OUTPUT_LANGUAGE tag is used to specify the language in which all +# documentation generated by doxygen is written. Doxygen will use this +# information to generate all constant output in the proper language. +# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese, +# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States), +# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian, +# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages), +# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian, +# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian, +# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish, +# Ukrainian and Vietnamese. +# The default value is: English. + +OUTPUT_LANGUAGE = English + +# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member +# descriptions after the members that are listed in the file and class +# documentation (similar to Javadoc). Set to NO to disable this. +# The default value is: YES. + +BRIEF_MEMBER_DESC = YES + +# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief +# description of a member or function before the detailed description +# +# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the +# brief descriptions will be completely suppressed. +# The default value is: YES. + +REPEAT_BRIEF = YES + +# This tag implements a quasi-intelligent brief description abbreviator that is +# used to form the text in various listings. Each string in this list, if found +# as the leading text of the brief description, will be stripped from the text +# and the result, after processing the whole list, is used as the annotated +# text. Otherwise, the brief description is used as-is. If left blank, the +# following values are used ($name is automatically replaced with the name of +# the entity):The $name class, The $name widget, The $name file, is, provides, +# specifies, contains, represents, a, an and the. + +ABBREVIATE_BRIEF = "The $name class" \ + "The $name widget" \ + "The $name file" \ + is \ + provides \ + specifies \ + contains \ + represents \ + a \ + an \ + the + +# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then +# doxygen will generate a detailed section even if there is only a brief +# description. +# The default value is: NO. + +ALWAYS_DETAILED_SEC = NO + +# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all +# inherited members of a class in the documentation of that class as if those +# members were ordinary class members. Constructors, destructors and assignment +# operators of the base classes will not be shown. +# The default value is: NO. + +INLINE_INHERITED_MEMB = NO + +# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path +# before files name in the file list and in the header files. If set to NO the +# shortest path that makes the file name unique will be used +# The default value is: YES. + +FULL_PATH_NAMES = YES + +# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path. +# Stripping is only done if one of the specified strings matches the left-hand +# part of the path. The tag can be used to show relative paths in the file list. +# If left blank the directory from which doxygen is run is used as the path to +# strip. +# +# Note that you can specify absolute paths here, but also relative paths, which +# will be relative from the directory where doxygen is started. +# This tag requires that the tag FULL_PATH_NAMES is set to YES. + +STRIP_FROM_PATH = + +# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the +# path mentioned in the documentation of a class, which tells the reader which +# header file to include in order to use a class. If left blank only the name of +# the header file containing the class definition is used. Otherwise one should +# specify the list of include paths that are normally passed to the compiler +# using the -I flag. + +STRIP_FROM_INC_PATH = + +# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but +# less readable) file names. This can be useful is your file systems doesn't +# support long names like on DOS, Mac, or CD-ROM. +# The default value is: NO. + +SHORT_NAMES = NO + +# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the +# first line (until the first dot) of a Javadoc-style comment as the brief +# description. If set to NO, the Javadoc-style will behave just like regular Qt- +# style comments (thus requiring an explicit @brief command for a brief +# description.) +# The default value is: NO. + +JAVADOC_AUTOBRIEF = NO + +# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first +# line (until the first dot) of a Qt-style comment as the brief description. If +# set to NO, the Qt-style will behave just like regular Qt-style comments (thus +# requiring an explicit \brief command for a brief description.) +# The default value is: NO. + +QT_AUTOBRIEF = NO + +# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a +# multi-line C++ special comment block (i.e. a block of //! or /// comments) as +# a brief description. This used to be the default behavior. The new default is +# to treat a multi-line C++ comment block as a detailed description. Set this +# tag to YES if you prefer the old behavior instead. +# +# Note that setting this tag to YES also means that rational rose comments are +# not recognized any more. +# The default value is: NO. + +MULTILINE_CPP_IS_BRIEF = NO + +# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the +# documentation from any documented member that it re-implements. +# The default value is: YES. + +INHERIT_DOCS = YES + +# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new +# page for each member. If set to NO, the documentation of a member will be part +# of the file/class/namespace that contains it. +# The default value is: NO. + +SEPARATE_MEMBER_PAGES = NO + +# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen +# uses this value to replace tabs by spaces in code fragments. +# Minimum value: 1, maximum value: 16, default value: 4. + +TAB_SIZE = 4 + +# This tag can be used to specify a number of aliases that act as commands in +# the documentation. An alias has the form: +# name=value +# For example adding +# "sideeffect=@par Side Effects:\n" +# will allow you to put the command \sideeffect (or @sideeffect) in the +# documentation, which will result in a user-defined paragraph with heading +# "Side Effects:". You can put \n's in the value part of an alias to insert +# newlines. + +ALIASES = + +# This tag can be used to specify a number of word-keyword mappings (TCL only). +# A mapping has the form "name=value". For example adding "class=itcl::class" +# will allow you to use the command class in the itcl::class meaning. + +TCL_SUBST = + +# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources +# only. Doxygen will then generate output that is more tailored for C. For +# instance, some of the names that are used will be different. The list of all +# members will be omitted, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_FOR_C = NO + +# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or +# Python sources only. Doxygen will then generate output that is more tailored +# for that language. For instance, namespaces will be presented as packages, +# qualified scopes will look different, etc. +# The default value is: NO. + +OPTIMIZE_OUTPUT_JAVA = NO + +# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran +# sources. Doxygen will then generate output that is tailored for Fortran. +# The default value is: NO. + +OPTIMIZE_FOR_FORTRAN = NO + +# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL +# sources. Doxygen will then generate output that is tailored for VHDL. +# The default value is: NO. + +OPTIMIZE_OUTPUT_VHDL = NO + +# Doxygen selects the parser to use depending on the extension of the files it +# parses. With this tag you can assign which parser to use for a given +# extension. Doxygen has a built-in mapping, but you can override or extend it +# using this tag. The format is ext=language, where ext is a file extension, and +# language is one of the parsers supported by doxygen: IDL, Java, Javascript, +# C#, C, C++, D, PHP, Objective-C, Python, Fortran (fixed format Fortran: +# FortranFixed, free formatted Fortran: FortranFree, unknown formatted Fortran: +# Fortran. In the later case the parser tries to guess whether the code is fixed +# or free formatted code, this is the default for Fortran type files), VHDL. For +# instance to make doxygen treat .inc files as Fortran files (default is PHP), +# and .f files as C (default is Fortran), use: inc=Fortran f=C. +# +# Note: For files without extension you can use no_extension as a placeholder. +# +# Note that for custom extensions you also need to set FILE_PATTERNS otherwise +# the files are not read by doxygen. + +EXTENSION_MAPPING = + +# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments +# according to the Markdown format, which allows for more readable +# documentation. See http://daringfireball.net/projects/markdown/ for details. +# The output of markdown processing is further processed by doxygen, so you can +# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in +# case of backward compatibilities issues. +# The default value is: YES. + +MARKDOWN_SUPPORT = YES + +# When enabled doxygen tries to link words that correspond to documented +# classes, or namespaces to their corresponding documentation. Such a link can +# be prevented in individual cases by putting a % sign in front of the word or +# globally by setting AUTOLINK_SUPPORT to NO. +# The default value is: YES. + +AUTOLINK_SUPPORT = YES + +# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want +# to include (a tag file for) the STL sources as input, then you should set this +# tag to YES in order to let doxygen match functions declarations and +# definitions whose arguments contain STL classes (e.g. func(std::string); +# versus func(std::string) {}). This also make the inheritance and collaboration +# diagrams that involve STL classes more complete and accurate. +# The default value is: NO. + +BUILTIN_STL_SUPPORT = YES + +# If you use Microsoft's C++/CLI language, you should set this option to YES to +# enable parsing support. +# The default value is: NO. + +CPP_CLI_SUPPORT = NO + +# Set the SIP_SUPPORT tag to YES if your project consists of sip (see: +# http://www.riverbankcomputing.co.uk/software/sip/intro) sources only. Doxygen +# will parse them like normal C++ but will assume all classes use public instead +# of private inheritance when no explicit protection keyword is present. +# The default value is: NO. + +SIP_SUPPORT = NO + +# For Microsoft's IDL there are propget and propput attributes to indicate +# getter and setter methods for a property. Setting this option to YES will make +# doxygen to replace the get and set methods by a property in the documentation. +# This will only work if the methods are indeed getting or setting a simple +# type. If this is not the case, or you want to show the methods anyway, you +# should set this option to NO. +# The default value is: YES. + +IDL_PROPERTY_SUPPORT = YES + +# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC +# tag is set to YES then doxygen will reuse the documentation of the first +# member in the group (if any) for the other members of the group. By default +# all members of a group must be documented explicitly. +# The default value is: NO. + +DISTRIBUTE_GROUP_DOC = YES + +# If one adds a struct or class to a group and this option is enabled, then also +# any nested class or struct is added to the same group. By default this option +# is disabled and one has to add nested compounds explicitly via \ingroup. +# The default value is: NO. + +GROUP_NESTED_COMPOUNDS = NO + +# Set the SUBGROUPING tag to YES to allow class member groups of the same type +# (for instance a group of public functions) to be put as a subgroup of that +# type (e.g. under the Public Functions section). Set it to NO to prevent +# subgrouping. Alternatively, this can be done per class using the +# \nosubgrouping command. +# The default value is: YES. + +SUBGROUPING = YES + +# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions +# are shown inside the group in which they are included (e.g. using \ingroup) +# instead of on a separate page (for HTML and Man pages) or section (for LaTeX +# and RTF). +# +# Note that this feature does not work in combination with +# SEPARATE_MEMBER_PAGES. +# The default value is: NO. + +INLINE_GROUPED_CLASSES = NO + +# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions +# with only public data fields or simple typedef fields will be shown inline in +# the documentation of the scope in which they are defined (i.e. file, +# namespace, or group documentation), provided this scope is documented. If set +# to NO, structs, classes, and unions are shown on a separate page (for HTML and +# Man pages) or section (for LaTeX and RTF). +# The default value is: NO. + +INLINE_SIMPLE_STRUCTS = NO + +# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or +# enum is documented as struct, union, or enum with the name of the typedef. So +# typedef struct TypeS {} TypeT, will appear in the documentation as a struct +# with name TypeT. When disabled the typedef will appear as a member of a file, +# namespace, or class. And the struct will be named TypeS. This can typically be +# useful for C code in case the coding convention dictates that all compound +# types are typedef'ed and only the typedef is referenced, never the tag name. +# The default value is: NO. + +TYPEDEF_HIDES_STRUCT = YES + +# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This +# cache is used to resolve symbols given their name and scope. Since this can be +# an expensive process and often the same symbol appears multiple times in the +# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small +# doxygen will become slower. If the cache is too large, memory is wasted. The +# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range +# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536 +# symbols. At the end of a run doxygen will report the cache usage and suggest +# the optimal cache size from a speed point of view. +# Minimum value: 0, maximum value: 9, default value: 0. + +LOOKUP_CACHE_SIZE = 0 + +#--------------------------------------------------------------------------- +# Build related configuration options +#--------------------------------------------------------------------------- + +# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in +# documentation are documented, even if no documentation was available. Private +# class members and static file members will be hidden unless the +# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES. +# Note: This will also disable the warnings about undocumented members that are +# normally produced when WARNINGS is set to YES. +# The default value is: NO. + +EXTRACT_ALL = YES + +# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will +# be included in the documentation. +# The default value is: NO. + +EXTRACT_PRIVATE = NO + +# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal +# scope will be included in the documentation. +# The default value is: NO. + +EXTRACT_PACKAGE = NO + +# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be +# included in the documentation. +# The default value is: NO. + +EXTRACT_STATIC = NO + +# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined +# locally in source files will be included in the documentation. If set to NO, +# only classes defined in header files are included. Does not have any effect +# for Java sources. +# The default value is: YES. + +EXTRACT_LOCAL_CLASSES = YES + +# This flag is only useful for Objective-C code. If set to YES, local methods, +# which are defined in the implementation section but not in the interface are +# included in the documentation. If set to NO, only methods in the interface are +# included. +# The default value is: NO. + +EXTRACT_LOCAL_METHODS = NO + +# If this flag is set to YES, the members of anonymous namespaces will be +# extracted and appear in the documentation as a namespace called +# 'anonymous_namespace{file}', where file will be replaced with the base name of +# the file that contains the anonymous namespace. By default anonymous namespace +# are hidden. +# The default value is: NO. + +EXTRACT_ANON_NSPACES = NO + +# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all +# undocumented members inside documented classes or files. If set to NO these +# members will be included in the various overviews, but no documentation +# section is generated. This option has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_MEMBERS = NO + +# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all +# undocumented classes that are normally visible in the class hierarchy. If set +# to NO, these classes will be included in the various overviews. This option +# has no effect if EXTRACT_ALL is enabled. +# The default value is: NO. + +HIDE_UNDOC_CLASSES = NO + +# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend +# (class|struct|union) declarations. If set to NO, these declarations will be +# included in the documentation. +# The default value is: NO. + +HIDE_FRIEND_COMPOUNDS = NO + +# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any +# documentation blocks found inside the body of a function. If set to NO, these +# blocks will be appended to the function's detailed documentation block. +# The default value is: NO. + +HIDE_IN_BODY_DOCS = NO + +# The INTERNAL_DOCS tag determines if documentation that is typed after a +# \internal command is included. If the tag is set to NO then the documentation +# will be excluded. Set it to YES to include the internal documentation. +# The default value is: NO. + +INTERNAL_DOCS = NO + +# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file +# names in lower-case letters. If set to YES, upper-case letters are also +# allowed. This is useful if you have classes or files whose names only differ +# in case and if your file system supports case sensitive file names. Windows +# and Mac users are advised to set this option to NO. +# The default value is: system dependent. + +CASE_SENSE_NAMES = NO + +# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with +# their full class and namespace scopes in the documentation. If set to YES, the +# scope will be hidden. +# The default value is: NO. + +HIDE_SCOPE_NAMES = NO + +# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will +# append additional text to a page's title, such as Class Reference. If set to +# YES the compound reference will be hidden. +# The default value is: NO. + +HIDE_COMPOUND_REFERENCE= NO + +# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of +# the files that are included by a file in the documentation of that file. +# The default value is: YES. + +SHOW_INCLUDE_FILES = YES + +# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each +# grouped member an include statement to the documentation, telling the reader +# which file to include in order to use the member. +# The default value is: NO. + +SHOW_GROUPED_MEMB_INC = NO + +# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include +# files with double quotes in the documentation rather than with sharp brackets. +# The default value is: NO. + +FORCE_LOCAL_INCLUDES = NO + +# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the +# documentation for inline members. +# The default value is: YES. + +INLINE_INFO = YES + +# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the +# (detailed) documentation of file and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. +# The default value is: YES. + +SORT_MEMBER_DOCS = YES + +# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief +# descriptions of file, namespace and class members alphabetically by member +# name. If set to NO, the members will appear in declaration order. Note that +# this will also influence the order of the classes in the class list. +# The default value is: NO. + +SORT_BRIEF_DOCS = NO + +# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the +# (brief and detailed) documentation of class members so that constructors and +# destructors are listed first. If set to NO the constructors will appear in the +# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS. +# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief +# member documentation. +# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting +# detailed member documentation. +# The default value is: NO. + +SORT_MEMBERS_CTORS_1ST = NO + +# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy +# of group names into alphabetical order. If set to NO the group names will +# appear in their defined order. +# The default value is: NO. + +SORT_GROUP_NAMES = NO + +# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by +# fully-qualified names, including namespaces. If set to NO, the class list will +# be sorted only by class name, not including the namespace part. +# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES. +# Note: This option applies only to the class list, not to the alphabetical +# list. +# The default value is: NO. + +SORT_BY_SCOPE_NAME = NO + +# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper +# type resolution of all parameters of a function it will reject a match between +# the prototype and the implementation of a member function even if there is +# only one candidate or it is obvious which candidate to choose by doing a +# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still +# accept a match between prototype and implementation in such cases. +# The default value is: NO. + +STRICT_PROTO_MATCHING = NO + +# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo +# list. This list is created by putting \todo commands in the documentation. +# The default value is: YES. + +GENERATE_TODOLIST = YES + +# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test +# list. This list is created by putting \test commands in the documentation. +# The default value is: YES. + +GENERATE_TESTLIST = YES + +# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug +# list. This list is created by putting \bug commands in the documentation. +# The default value is: YES. + +GENERATE_BUGLIST = YES + +# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO) +# the deprecated list. This list is created by putting \deprecated commands in +# the documentation. +# The default value is: YES. + +GENERATE_DEPRECATEDLIST= YES + +# The ENABLED_SECTIONS tag can be used to enable conditional documentation +# sections, marked by \if ... \endif and \cond +# ... \endcond blocks. + +ENABLED_SECTIONS = + +# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the +# initial value of a variable or macro / define can have for it to appear in the +# documentation. If the initializer consists of more lines than specified here +# it will be hidden. Use a value of 0 to hide initializers completely. The +# appearance of the value of individual variables and macros / defines can be +# controlled using \showinitializer or \hideinitializer command in the +# documentation regardless of this setting. +# Minimum value: 0, maximum value: 10000, default value: 30. + +MAX_INITIALIZER_LINES = 30 + +# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at +# the bottom of the documentation of classes and structs. If set to YES, the +# list will mention the files that were used to generate the documentation. +# The default value is: YES. + +SHOW_USED_FILES = YES + +# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This +# will remove the Files entry from the Quick Index and from the Folder Tree View +# (if specified). +# The default value is: YES. + +SHOW_FILES = YES + +# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces +# page. This will remove the Namespaces entry from the Quick Index and from the +# Folder Tree View (if specified). +# The default value is: YES. + +SHOW_NAMESPACES = YES + +# The FILE_VERSION_FILTER tag can be used to specify a program or script that +# doxygen should invoke to get the current version for each file (typically from +# the version control system). Doxygen will invoke the program by executing (via +# popen()) the command command input-file, where command is the value of the +# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided +# by doxygen. Whatever the program writes to standard output is used as the file +# version. For an example see the documentation. + +FILE_VERSION_FILTER = + +# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed +# by doxygen. The layout file controls the global structure of the generated +# output files in an output format independent way. To create the layout file +# that represents doxygen's defaults, run doxygen with the -l option. You can +# optionally specify a file name after the option, if omitted DoxygenLayout.xml +# will be used as the name of the layout file. +# +# Note that if you run doxygen from a directory containing a file called +# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE +# tag is left empty. + +LAYOUT_FILE = + +# The CITE_BIB_FILES tag can be used to specify one or more bib files containing +# the reference definitions. This must be a list of .bib files. The .bib +# extension is automatically appended if omitted. This requires the bibtex tool +# to be installed. See also http://en.wikipedia.org/wiki/BibTeX for more info. +# For LaTeX the style of the bibliography can be controlled using +# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the +# search path. See also \cite for info how to create references. + +CITE_BIB_FILES = + +#--------------------------------------------------------------------------- +# Configuration options related to warning and progress messages +#--------------------------------------------------------------------------- + +# The QUIET tag can be used to turn on/off the messages that are generated to +# standard output by doxygen. If QUIET is set to YES this implies that the +# messages are off. +# The default value is: NO. + +QUIET = NO + +# The WARNINGS tag can be used to turn on/off the warning messages that are +# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES +# this implies that the warnings are on. +# +# Tip: Turn warnings on while writing the documentation. +# The default value is: YES. + +WARNINGS = YES + +# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate +# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag +# will automatically be disabled. +# The default value is: YES. + +WARN_IF_UNDOCUMENTED = YES + +# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for +# potential errors in the documentation, such as not documenting some parameters +# in a documented function, or documenting parameters that don't exist or using +# markup commands wrongly. +# The default value is: YES. + +WARN_IF_DOC_ERROR = YES + +# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that +# are documented, but have no documentation for their parameters or return +# value. If set to NO, doxygen will only warn about wrong or incomplete +# parameter documentation, but not about the absence of documentation. +# The default value is: NO. + +WARN_NO_PARAMDOC = NO + +# The WARN_FORMAT tag determines the format of the warning messages that doxygen +# can produce. The string should contain the $file, $line, and $text tags, which +# will be replaced by the file and line number from which the warning originated +# and the warning text. Optionally the format may contain $version, which will +# be replaced by the version of the file (if it could be obtained via +# FILE_VERSION_FILTER) +# The default value is: $file:$line: $text. + +WARN_FORMAT = "$file:$line: $text" + +# The WARN_LOGFILE tag can be used to specify a file to which warning and error +# messages should be written. If left blank the output is written to standard +# error (stderr). + +WARN_LOGFILE = + +#--------------------------------------------------------------------------- +# Configuration options related to the input files +#--------------------------------------------------------------------------- + +# The INPUT tag is used to specify the files and/or directories that contain +# documented source files. You may enter file names like myfile.cpp or +# directories like /usr/src/myproject. Separate the files or directories with +# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING +# Note: If this tag is empty the current directory is searched. + +INPUT = ../library/include \ + ../library/include/internal + +# This tag can be used to specify the character encoding of the source files +# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses +# libiconv (or the iconv built into libc) for the transcoding. See the libiconv +# documentation (see: http://www.gnu.org/software/libiconv) for the list of +# possible encodings. +# The default value is: UTF-8. + +INPUT_ENCODING = UTF-8 + +# If the value of the INPUT tag contains directories, you can use the +# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and +# *.h) to filter out the source-files in the directories. +# +# Note that for custom extensions or not directly supported extensions you also +# need to set EXTENSION_MAPPING for the extension otherwise the files are not +# read by doxygen. +# +# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp, +# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h, +# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc, +# *.m, *.markdown, *.md, *.mm, *.dox, *.py, *.f90, *.f, *.for, *.tcl, *.vhd, +# *.vhdl, *.ucf, *.qsf, *.as and *.js. + +FILE_PATTERNS = *.c \ + *.cc \ + *.cxx \ + *.cpp \ + *.c++ \ + *.java \ + *.ii \ + *.ixx \ + *.ipp \ + *.i++ \ + *.inl \ + *.idl \ + *.ddl \ + *.odl \ + *.h \ + *.hh \ + *.hxx \ + *.hpp \ + *.h++ \ + *.cs \ + *.d \ + *.php \ + *.php4 \ + *.php5 \ + *.phtml \ + *.inc \ + *.m \ + *.markdown \ + *.md \ + *.mm \ + *.dox \ + *.py \ + *.tcl \ + *.vhd \ + *.vhdl \ + *.ucf \ + *.qsf \ + *.as \ + *.js + +# The RECURSIVE tag can be used to specify whether or not subdirectories should +# be searched for input files as well. +# The default value is: NO. + +RECURSIVE = NO + +# The EXCLUDE tag can be used to specify files and/or directories that should be +# excluded from the INPUT source files. This way you can easily exclude a +# subdirectory from a directory tree whose root is specified with the INPUT tag. +# +# Note that relative paths are relative to the directory from which doxygen is +# run. + +EXCLUDE = + +# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or +# directories that are symbolic links (a Unix file system feature) are excluded +# from the input. +# The default value is: NO. + +EXCLUDE_SYMLINKS = NO + +# If the value of the INPUT tag contains directories, you can use the +# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude +# certain files from those directories. +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories for example use the pattern */test/* + +EXCLUDE_PATTERNS = + +# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names +# (namespaces, classes, functions, etc.) that should be excluded from the +# output. The symbol name can be a fully qualified name, a word, or if the +# wildcard * is used, a substring. Examples: ANamespace, AClass, +# AClass::ANamespace, ANamespace::*Test +# +# Note that the wildcards are matched against the file with absolute path, so to +# exclude all test directories use the pattern */test/* + +EXCLUDE_SYMBOLS = + +# The EXAMPLE_PATH tag can be used to specify one or more files or directories +# that contain example code fragments that are included (see the \include +# command). + +EXAMPLE_PATH = + +# If the value of the EXAMPLE_PATH tag contains directories, you can use the +# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and +# *.h) to filter out the source-files in the directories. If left blank all +# files are included. + +EXAMPLE_PATTERNS = * + +# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be +# searched for input files to be used with the \include or \dontinclude commands +# irrespective of the value of the RECURSIVE tag. +# The default value is: NO. + +EXAMPLE_RECURSIVE = NO + +# The IMAGE_PATH tag can be used to specify one or more files or directories +# that contain images that are to be included in the documentation (see the +# \image command). + +IMAGE_PATH = + +# The INPUT_FILTER tag can be used to specify a program that doxygen should +# invoke to filter for each input file. Doxygen will invoke the filter program +# by executing (via popen()) the command: +# +# +# +# where is the value of the INPUT_FILTER tag, and is the +# name of an input file. Doxygen will then use the output that the filter +# program writes to standard output. If FILTER_PATTERNS is specified, this tag +# will be ignored. +# +# Note that the filter must not add or remove lines; it is applied before the +# code is scanned, but not when the output code is generated. If lines are added +# or removed, the anchors will not be placed correctly. + +INPUT_FILTER = + +# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern +# basis. Doxygen will compare the file name with each pattern and apply the +# filter if there is a match. The filters are a list of the form: pattern=filter +# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how +# filters are used. If the FILTER_PATTERNS tag is empty or if none of the +# patterns match the file name, INPUT_FILTER is applied. + +FILTER_PATTERNS = + +# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using +# INPUT_FILTER) will also be used to filter the input files that are used for +# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES). +# The default value is: NO. + +FILTER_SOURCE_FILES = NO + +# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file +# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and +# it is also possible to disable source filtering for a specific pattern using +# *.ext= (so without naming a filter). +# This tag requires that the tag FILTER_SOURCE_FILES is set to YES. + +FILTER_SOURCE_PATTERNS = + +# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that +# is part of the input, its contents will be placed on the main page +# (index.html). This can be useful if you have a project on for instance GitHub +# and want to reuse the introduction page also for the doxygen output. + +USE_MDFILE_AS_MAINPAGE = ../README.md + +#--------------------------------------------------------------------------- +# Configuration options related to source browsing +#--------------------------------------------------------------------------- + +# If the SOURCE_BROWSER tag is set to YES then a list of source files will be +# generated. Documented entities will be cross-referenced with these sources. +# +# Note: To get rid of all source code in the generated output, make sure that +# also VERBATIM_HEADERS is set to NO. +# The default value is: NO. + +SOURCE_BROWSER = NO + +# Setting the INLINE_SOURCES tag to YES will include the body of functions, +# classes and enums directly into the documentation. +# The default value is: NO. + +INLINE_SOURCES = NO + +# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any +# special comment blocks from generated source code fragments. Normal C, C++ and +# Fortran comments will always remain visible. +# The default value is: YES. + +STRIP_CODE_COMMENTS = YES + +# If the REFERENCED_BY_RELATION tag is set to YES then for each documented +# function all documented functions referencing it will be listed. +# The default value is: NO. + +REFERENCED_BY_RELATION = NO + +# If the REFERENCES_RELATION tag is set to YES then for each documented function +# all documented entities called/used by that function will be listed. +# The default value is: NO. + +REFERENCES_RELATION = NO + +# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set +# to YES then the hyperlinks from functions in REFERENCES_RELATION and +# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will +# link to the documentation. +# The default value is: YES. + +REFERENCES_LINK_SOURCE = YES + +# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the +# source code will show a tooltip with additional information such as prototype, +# brief description and links to the definition and documentation. Since this +# will make the HTML file larger and loading of large files a bit slower, you +# can opt to disable this feature. +# The default value is: YES. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +SOURCE_TOOLTIPS = YES + +# If the USE_HTAGS tag is set to YES then the references to source code will +# point to the HTML generated by the htags(1) tool instead of doxygen built-in +# source browser. The htags tool is part of GNU's global source tagging system +# (see http://www.gnu.org/software/global/global.html). You will need version +# 4.8.6 or higher. +# +# To use it do the following: +# - Install the latest version of global +# - Enable SOURCE_BROWSER and USE_HTAGS in the config file +# - Make sure the INPUT points to the root of the source tree +# - Run doxygen as normal +# +# Doxygen will invoke htags (and that will in turn invoke gtags), so these +# tools must be available from the command line (i.e. in the search path). +# +# The result: instead of the source browser generated by doxygen, the links to +# source code will now point to the output of htags. +# The default value is: NO. +# This tag requires that the tag SOURCE_BROWSER is set to YES. + +USE_HTAGS = NO + +# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a +# verbatim copy of the header file for each class for which an include is +# specified. Set to NO to disable this. +# See also: Section \class. +# The default value is: YES. + +VERBATIM_HEADERS = YES + +# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the +# clang parser (see: http://clang.llvm.org/) for more accurate parsing at the +# cost of reduced performance. This can be particularly helpful with template +# rich C++ code for which doxygen's built-in parser lacks the necessary type +# information. +# Note: The availability of this option depends on whether or not doxygen was +# compiled with the --with-libclang option. +# The default value is: NO. + +CLANG_ASSISTED_PARSING = NO + +# If clang assisted parsing is enabled you can provide the compiler with command +# line options that you would normally use when invoking the compiler. Note that +# the include paths will already be set by doxygen for the files and directories +# specified with INPUT and INCLUDE_PATH. +# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES. + +CLANG_OPTIONS = + +#--------------------------------------------------------------------------- +# Configuration options related to the alphabetical class index +#--------------------------------------------------------------------------- + +# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all +# compounds will be generated. Enable this if the project contains a lot of +# classes, structs, unions or interfaces. +# The default value is: YES. + +ALPHABETICAL_INDEX = YES + +# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in +# which the alphabetical index list will be split. +# Minimum value: 1, maximum value: 20, default value: 5. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + +COLS_IN_ALPHA_INDEX = 5 + +# In case all classes in a project start with a common prefix, all classes will +# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag +# can be used to specify a prefix (or a list of prefixes) that should be ignored +# while generating the index headers. +# This tag requires that the tag ALPHABETICAL_INDEX is set to YES. + +IGNORE_PREFIX = + +#--------------------------------------------------------------------------- +# Configuration options related to the HTML output +#--------------------------------------------------------------------------- + +# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output +# The default value is: YES. + +GENERATE_HTML = YES + +# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a +# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of +# it. +# The default directory is: html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_OUTPUT = html + +# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each +# generated HTML page (for example: .htm, .php, .asp). +# The default value is: .html. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FILE_EXTENSION = .html + +# The HTML_HEADER tag can be used to specify a user-defined HTML header file for +# each generated HTML page. If the tag is left blank doxygen will generate a +# standard header. +# +# To get valid HTML the header file that includes any scripts and style sheets +# that doxygen needs, which is dependent on the configuration options used (e.g. +# the setting GENERATE_TREEVIEW). It is highly recommended to start with a +# default header using +# doxygen -w html new_header.html new_footer.html new_stylesheet.css +# YourConfigFile +# and then modify the file new_header.html. See also section "Doxygen usage" +# for information on how to generate the default header that doxygen normally +# uses. +# Note: The header is subject to change so you typically have to regenerate the +# default header when upgrading to a newer version of doxygen. For a description +# of the possible markers and block names see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_HEADER = + +# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each +# generated HTML page. If the tag is left blank doxygen will generate a standard +# footer. See HTML_HEADER for more information on how to generate a default +# footer and what special commands can be used inside the footer. See also +# section "Doxygen usage" for information on how to generate the default footer +# that doxygen normally uses. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_FOOTER = + +# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style +# sheet that is used by each HTML page. It can be used to fine-tune the look of +# the HTML output. If left blank doxygen will generate a default style sheet. +# See also section "Doxygen usage" for information on how to generate the style +# sheet that doxygen normally uses. +# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as +# it is more robust and this tag (HTML_STYLESHEET) will in the future become +# obsolete. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_STYLESHEET = + +# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined +# cascading style sheets that are included after the standard style sheets +# created by doxygen. Using this option one can overrule certain style aspects. +# This is preferred over using HTML_STYLESHEET since it does not replace the +# standard style sheet and is therefore more robust against future updates. +# Doxygen will copy the style sheet files to the output directory. +# Note: The order of the extra style sheet files is of importance (e.g. the last +# style sheet in the list overrules the setting of the previous ones in the +# list). For an example see the documentation. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_STYLESHEET = + +# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or +# other source files which should be copied to the HTML output directory. Note +# that these files will be copied to the base HTML output directory. Use the +# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these +# files. In the HTML_STYLESHEET file, use the file name only. Also note that the +# files will be copied as-is; there are no commands or markers available. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_EXTRA_FILES = + +# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen +# will adjust the colors in the style sheet and background images according to +# this color. Hue is specified as an angle on a colorwheel, see +# http://en.wikipedia.org/wiki/Hue for more information. For instance the value +# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300 +# purple, and 360 is red again. +# Minimum value: 0, maximum value: 359, default value: 220. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_HUE = 220 + +# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors +# in the HTML output. For a value of 0 the output will use grayscales only. A +# value of 255 will produce the most vivid colors. +# Minimum value: 0, maximum value: 255, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_SAT = 100 + +# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the +# luminance component of the colors in the HTML output. Values below 100 +# gradually make the output lighter, whereas values above 100 make the output +# darker. The value divided by 100 is the actual gamma applied, so 80 represents +# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not +# change the gamma. +# Minimum value: 40, maximum value: 240, default value: 80. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_COLORSTYLE_GAMMA = 80 + +# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML +# page will contain the date and time when the page was generated. Setting this +# to YES can help to show when doxygen was last run and thus if the +# documentation is up to date. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_TIMESTAMP = NO + +# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML +# documentation will contain sections that can be hidden and shown after the +# page has loaded. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_DYNAMIC_SECTIONS = NO + +# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries +# shown in the various tree structured indices initially; the user can expand +# and collapse entries dynamically later on. Doxygen will expand the tree to +# such a level that at most the specified number of entries are visible (unless +# a fully collapsed tree already exceeds this amount). So setting the number of +# entries 1 will produce a full collapsed tree by default. 0 is a special value +# representing an infinite number of entries and will result in a full expanded +# tree by default. +# Minimum value: 0, maximum value: 9999, default value: 100. +# This tag requires that the tag GENERATE_HTML is set to YES. + +HTML_INDEX_NUM_ENTRIES = 100 + +# If the GENERATE_DOCSET tag is set to YES, additional index files will be +# generated that can be used as input for Apple's Xcode 3 integrated development +# environment (see: http://developer.apple.com/tools/xcode/), introduced with +# OSX 10.5 (Leopard). To create a documentation set, doxygen will generate a +# Makefile in the HTML output directory. Running make will produce the docset in +# that directory and running make install will install the docset in +# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at +# startup. See http://developer.apple.com/tools/creatingdocsetswithdoxygen.html +# for more information. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_DOCSET = NO + +# This tag determines the name of the docset feed. A documentation feed provides +# an umbrella under which multiple documentation sets from a single provider +# (such as a company or product suite) can be grouped. +# The default value is: Doxygen generated docs. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_FEEDNAME = "Doxygen generated docs" + +# This tag specifies a string that should uniquely identify the documentation +# set bundle. This should be a reverse domain-name style string, e.g. +# com.mycompany.MyDocSet. Doxygen will append .docset to the name. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_BUNDLE_ID = org.doxygen.Project + +# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify +# the documentation publisher. This should be a reverse domain-name style +# string, e.g. com.mycompany.MyDocSet.documentation. +# The default value is: org.doxygen.Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_ID = org.doxygen.Publisher + +# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher. +# The default value is: Publisher. +# This tag requires that the tag GENERATE_DOCSET is set to YES. + +DOCSET_PUBLISHER_NAME = Publisher + +# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three +# additional HTML index files: index.hhp, index.hhc, and index.hhk. The +# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop +# (see: http://www.microsoft.com/en-us/download/details.aspx?id=21138) on +# Windows. +# +# The HTML Help Workshop contains a compiler that can convert all HTML output +# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML +# files are now used as the Windows 98 help format, and will replace the old +# Windows help format (.hlp) on all Windows platforms in the future. Compressed +# HTML files also contain an index, a table of contents, and you can search for +# words in the documentation. The HTML workshop also contains a viewer for +# compressed HTML files. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_HTMLHELP = NO + +# The CHM_FILE tag can be used to specify the file name of the resulting .chm +# file. You can add a path in front of the file if the result should not be +# written to the html output directory. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_FILE = + +# The HHC_LOCATION tag can be used to specify the location (absolute path +# including file name) of the HTML help compiler (hhc.exe). If non-empty, +# doxygen will try to run the HTML help compiler on the generated index.hhp. +# The file has to be specified with full path. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +HHC_LOCATION = + +# The GENERATE_CHI flag controls if a separate .chi index file is generated +# (YES) or that it should be included in the master .chm file (NO). +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +GENERATE_CHI = NO + +# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc) +# and project file content. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +CHM_INDEX_ENCODING = + +# The BINARY_TOC flag controls whether a binary table of contents is generated +# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it +# enables the Previous and Next buttons. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +BINARY_TOC = NO + +# The TOC_EXPAND flag can be set to YES to add extra items for group members to +# the table of contents of the HTML help documentation and to the tree view. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTMLHELP is set to YES. + +TOC_EXPAND = NO + +# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and +# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that +# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help +# (.qch) of the generated HTML documentation. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_QHP = NO + +# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify +# the file name of the resulting .qch file. The path specified is relative to +# the HTML output folder. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QCH_FILE = + +# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help +# Project output. For more information please see Qt Help Project / Namespace +# (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#namespace). +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_NAMESPACE = org.doxygen.Project + +# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt +# Help Project output. For more information please see Qt Help Project / Virtual +# Folders (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#virtual- +# folders). +# The default value is: doc. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_VIRTUAL_FOLDER = doc + +# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom +# filter to add. For more information please see Qt Help Project / Custom +# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- +# filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_NAME = + +# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the +# custom filter to add. For more information please see Qt Help Project / Custom +# Filters (see: http://qt-project.org/doc/qt-4.8/qthelpproject.html#custom- +# filters). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_CUST_FILTER_ATTRS = + +# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this +# project's filter section matches. Qt Help Project / Filter Attributes (see: +# http://qt-project.org/doc/qt-4.8/qthelpproject.html#filter-attributes). +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHP_SECT_FILTER_ATTRS = + +# The QHG_LOCATION tag can be used to specify the location of Qt's +# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the +# generated .qhp file. +# This tag requires that the tag GENERATE_QHP is set to YES. + +QHG_LOCATION = + +# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be +# generated, together with the HTML files, they form an Eclipse help plugin. To +# install this plugin and make it available under the help contents menu in +# Eclipse, the contents of the directory containing the HTML and XML files needs +# to be copied into the plugins directory of eclipse. The name of the directory +# within the plugins directory should be the same as the ECLIPSE_DOC_ID value. +# After copying Eclipse needs to be restarted before the help appears. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_ECLIPSEHELP = NO + +# A unique identifier for the Eclipse help plugin. When installing the plugin +# the directory name containing the HTML and XML files should also have this +# name. Each documentation set should have its own identifier. +# The default value is: org.doxygen.Project. +# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES. + +ECLIPSE_DOC_ID = org.doxygen.Project + +# If you want full control over the layout of the generated HTML pages it might +# be necessary to disable the index and replace it with your own. The +# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top +# of each HTML page. A value of NO enables the index and the value YES disables +# it. Since the tabs in the index contain the same information as the navigation +# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +DISABLE_INDEX = NO + +# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index +# structure should be generated to display hierarchical information. If the tag +# value is set to YES, a side panel will be generated containing a tree-like +# index structure (just like the one that is generated for HTML Help). For this +# to work a browser that supports JavaScript, DHTML, CSS and frames is required +# (i.e. any modern browser). Windows users are probably better off using the +# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can +# further fine-tune the look of the index. As an example, the default style +# sheet generated by doxygen has an example that shows how to put an image at +# the root of the tree instead of the PROJECT_NAME. Since the tree basically has +# the same information as the tab index, you could consider setting +# DISABLE_INDEX to YES when enabling this option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +GENERATE_TREEVIEW = NO + +# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that +# doxygen will group on one line in the generated HTML documentation. +# +# Note that a value of 0 will completely suppress the enum values from appearing +# in the overview section. +# Minimum value: 0, maximum value: 20, default value: 4. +# This tag requires that the tag GENERATE_HTML is set to YES. + +ENUM_VALUES_PER_LINE = 1 + +# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used +# to set the initial width (in pixels) of the frame in which the tree is shown. +# Minimum value: 0, maximum value: 1500, default value: 250. +# This tag requires that the tag GENERATE_HTML is set to YES. + +TREEVIEW_WIDTH = 250 + +# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to +# external symbols imported via tag files in a separate window. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +EXT_LINKS_IN_WINDOW = NO + +# Use this tag to change the font size of LaTeX formulas included as images in +# the HTML documentation. When you change the font size after a successful +# doxygen run you need to manually remove any form_*.png images from the HTML +# output directory to force them to be regenerated. +# Minimum value: 8, maximum value: 50, default value: 10. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FORMULA_FONTSIZE = 10 + +# Use the FORMULA_TRANPARENT tag to determine whether or not the images +# generated for formulas are transparent PNGs. Transparent PNGs are not +# supported properly for IE 6.0, but are supported on all modern browsers. +# +# Note that when changing this option you need to delete any form_*.png files in +# the HTML output directory before the changes have effect. +# The default value is: YES. +# This tag requires that the tag GENERATE_HTML is set to YES. + +FORMULA_TRANSPARENT = YES + +# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see +# http://www.mathjax.org) which uses client side Javascript for the rendering +# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX +# installed or if you want to formulas look prettier in the HTML output. When +# enabled you may also need to install MathJax separately and configure the path +# to it using the MATHJAX_RELPATH option. +# The default value is: NO. +# This tag requires that the tag GENERATE_HTML is set to YES. + +USE_MATHJAX = YES + +# When MathJax is enabled you can set the default output format to be used for +# the MathJax output. See the MathJax site (see: +# http://docs.mathjax.org/en/latest/output.html) for more details. +# Possible values are: HTML-CSS (which is slower, but has the best +# compatibility), NativeMML (i.e. MathML) and SVG. +# The default value is: HTML-CSS. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_FORMAT = HTML-CSS + +# When MathJax is enabled you need to specify the location relative to the HTML +# output directory using the MATHJAX_RELPATH option. The destination directory +# should contain the MathJax.js script. For instance, if the mathjax directory +# is located at the same level as the HTML output directory, then +# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax +# Content Delivery Network so you can quickly see the result without installing +# MathJax. However, it is strongly recommended to install a local copy of +# MathJax from http://www.mathjax.org before deployment. +# The default value is: http://cdn.mathjax.org/mathjax/latest. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_RELPATH = http://cdn.mathjax.org/mathjax/latest + +# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax +# extension names that should be enabled during MathJax rendering. For example +# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_EXTENSIONS = + +# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces +# of code that will be used on startup of the MathJax code. See the MathJax site +# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an +# example see the documentation. +# This tag requires that the tag USE_MATHJAX is set to YES. + +MATHJAX_CODEFILE = + +# When the SEARCHENGINE tag is enabled doxygen will generate a search box for +# the HTML output. The underlying search engine uses javascript and DHTML and +# should work on any modern browser. Note that when using HTML help +# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET) +# there is already a search function so this one should typically be disabled. +# For large projects the javascript based search engine can be slow, then +# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to +# search using the keyboard; to jump to the search box use + S +# (what the is depends on the OS and browser, but it is typically +# , /